From a002797d527c0846dfc9de1dc4c13009758bd02f Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 21 Dec 2023 17:59:15 -0800
Subject: [PATCH 01/37] A temporary fix (#254)

---
 mlx/backend/metal/device.cpp | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index 68662763a..eda54ea89 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -108,6 +108,16 @@ MTL::Library* load_library(
   }
 }
 
+struct PoolHolder {
+  PoolHolder() {
+    p = NS::AutoreleasePool::alloc()->init();
+  }
+  ~PoolHolder() {
+    p->release();
+  }
+  NS::AutoreleasePool* p;
+};
+
 } // namespace
 
 Device::Device()
@@ -125,6 +135,12 @@ Device::~Device() {
   for (auto& l : library_map_) {
     l.second->release();
   }
+  for (auto& b : buffer_map_) {
+    b.second.second->release();
+  }
+  for (auto& e : encoder_map_) {
+    e.second->release();
+  }
   device_->release();
   pool_->release();
 }
@@ -282,9 +298,8 @@ Device& device(mlx::core::Device) {
 }
 
 NS::AutoreleasePool*& thread_autorelease_pool() {
-  static thread_local NS::AutoreleasePool* p =
-      NS::AutoreleasePool::alloc()->init();
-  return p;
+  static thread_local PoolHolder pool{};
+  return pool.p;
 }
 
 void new_stream(Stream stream) {

From 2118c3dbfa509304cba234a1df02c643a4a85ca6 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 21 Dec 2023 18:18:41 -0800
Subject: [PATCH 02/37] fix (#255)

---
 mlx/backend/metal/device.cpp | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index eda54ea89..6b6158f29 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -17,6 +17,8 @@ namespace fs = std::filesystem;
 
 namespace mlx::core::metal {
 
+static Device metal_device_;
+
 namespace {
 
 // TODO nicer way to set this or possibly expose as an environment variable
@@ -108,16 +110,6 @@ MTL::Library* load_library(
   }
 }
 
-struct PoolHolder {
-  PoolHolder() {
-    p = NS::AutoreleasePool::alloc()->init();
-  }
-  ~PoolHolder() {
-    p->release();
-  }
-  NS::AutoreleasePool* p;
-};
-
 } // namespace
 
 Device::Device()
@@ -293,13 +285,13 @@ MTL::ComputePipelineState* Device::get_kernel(
 }
 
 Device& device(mlx::core::Device) {
-  static Device metal_device_;
   return metal_device_;
 }
 
 NS::AutoreleasePool*& thread_autorelease_pool() {
-  static thread_local PoolHolder pool{};
-  return pool.p;
+  static thread_local NS::AutoreleasePool* p =
+      NS::AutoreleasePool::alloc()->init();
+  return p;
 }
 
 void new_stream(Stream stream) {

From 8385f93cea75cf07f79202f8d364a6d59b554ae2 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Thu, 21 Dec 2023 18:33:14 -0800
Subject: [PATCH 03/37] Bumping the version (#256)

---
 CMakeLists.txt   | 4 ++--
 docs/src/conf.py | 4 ++--
 setup.py         | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2ea908981..70293ebba 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,7 +18,7 @@ option(MLX_BUILD_METAL "Build metal backend" ON)
 option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
 
 if(NOT MLX_VERSION)
-  set(MLX_VERSION 0.0.3)
+  set(MLX_VERSION 0.0.6)
 endif()
 
 # --------------------- Processor tests -------------------------
@@ -221,4 +221,4 @@ install(
 install(
   DIRECTORY ${CMAKE_MODULE_PATH}/
   DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
-)
\ No newline at end of file
+)
diff --git a/docs/src/conf.py b/docs/src/conf.py
index a5fbf5b16..d38d3424f 100644
--- a/docs/src/conf.py
+++ b/docs/src/conf.py
@@ -10,8 +10,8 @@ import subprocess
 
 project = "MLX"
 copyright = "2023, MLX Contributors"
 author = "MLX Contributors"
-version = "0.0.5"
-release = "0.0.5"
+version = "0.0.6"
+release = "0.0.6"
 
 # -- General configuration ---------------------------------------------------
diff --git a/setup.py b/setup.py
index 4711a87c9..f5d04e959 100644
--- a/setup.py
+++ b/setup.py
@@ -165,7 +165,7 @@ if __name__ == "__main__":
     setup(
         name="mlx",
-        version=get_version("0.0.5"),
+        version=get_version("0.0.6"),
         author="MLX Contributors",
         author_email="mlx@group.apple.com",
         description="A framework for machine learning on Apple silicon.",

From e8deca84e0c225b0356844a69f02a6121701ce23 Mon Sep 17 00:00:00 2001
From: Justin Deschenaux <33008801+deschena@users.noreply.github.com>
Date: Fri, 22 Dec 2023 17:02:29 +0100
Subject: [PATCH 04/37] Add dropout2d (#250)

---
 ACKNOWLEDGMENTS.md               |  2 +-
 docs/src/python/nn/layers.rst    |  3 ++
 python/mlx/nn/layers/__init__.py |  2 +-
 python/mlx/nn/layers/dropout.py  | 59 +++++++++++++++++++++++++++++++-
 4 files changed, 63 insertions(+), 3 deletions(-)

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index c9969f8d6..9e3b27532 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -8,7 +8,7 @@ with a short description of your contribution(s) below. For example:
 MLX was developed with contributions from the following individuals:
 
 - Juarez Bochi: Fixed bug in cross attention.
-- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, linear and logistic regression python example.
+- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
 
 # Third-Party Software
 
diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index fab3ff785..4c7d4aa79 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -27,3 +27,6 @@ Layers
    MultiHeadAttention
    Sequential
    QuantizedLinear
+   Dropout
+   Dropout2d
+
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index 3f03064bf..d54e45f6d 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -33,7 +33,7 @@ from mlx.nn.layers.activations import (
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.containers import Sequential
 from mlx.nn.layers.convolution import Conv1d, Conv2d
-from mlx.nn.layers.dropout import Dropout
+from mlx.nn.layers.dropout import Dropout, Dropout2d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
 from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py
index 3193cdbd7..e2cc981e2 100644
--- a/python/mlx/nn/layers/dropout.py
+++ b/python/mlx/nn/layers/dropout.py
@@ -32,4 +32,61 @@ class Dropout(Module):
 
         mask = mx.random.bernoulli(self._p_1, x.shape)
 
-        return (1 / self._p_1) * mask.astype(x.dtype) * x
+        return (1 / self._p_1) * mask * x
+
+
+class Dropout2d(Module):
+    """Apply 2D channel-wise dropout during training.
+
+    Randomly zero out entire channels independently with probability :math:`p`.
+    This layer expects the channels to be last, i.e. the input shape should be
+    ``NWHC`` or ``WHC`` where:
+        - ``N`` is the batch dimension
+        - ``H`` is the input image height
+        - ``W`` is the input image width
+        - ``C`` is the number of input channels
+
+    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
+    maintain the expected value of each element. Unlike traditional dropout,
+    which zeros individual entries, this layer zeros entire channels. This is
+    beneficial for early convolution layers where adjacent pixels are
+    correlated. In such case, traditional dropout may not effectively
+    regularize activations. For more details, see [1].
+
+    [1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015.
+    Efficient Object Localization Using Convolutional Networks. CVPR 2015.
+
+    Args:
+        p (float): Probability of zeroing a channel during training.
+    """
+
+    def __init__(self, p: float = 0.5):
+        super().__init__()
+
+        if p < 0 or p >= 1:
+            raise ValueError("The dropout probability should be in [0, 1)")
+
+        self._p_1 = 1 - p
+
+    def _extra_repr(self):
+        return f"p={1-self._p_1}"
+
+    def __call__(self, x):
+        if x.ndim not in (3, 4):
+            raise ValueError(
+                f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions."
+            )
+
+        if self._p_1 == 1 or not self.training:
+            return x
+
+        # Dropout is applied on the whole channel
+        # 3D input: (1, 1, C)
+        # 4D input: (B, 1, 1, C)
+        mask_shape = x.shape
+        mask_shape[-2] = mask_shape[-3] = 1
+
+        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
+        return (1 / self._p_1) * mask * x
+
+
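
[A usage sketch for the Dropout2d layer added above. It follows the docstring and the usual mlx.nn Module train()/eval() toggles, so treat it as an illustration alongside the patch rather than part of it:]

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((8, 32, 32, 4))  # channels-last input, as the docstring expects
    layer = nn.Dropout2d(p=0.5)
    y = layer(x)   # training mode: whole channels zeroed, survivors scaled by 1/(1-p)
    layer.eval()
    z = layer(x)   # evaluation mode: identity, z == x
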
From d35fa1db41b1d883998e73c6bb8f96cc1490ee03 Mon Sep 17 00:00:00 2001
From: Nicholas Santavas
Date: Fri, 22 Dec 2023 19:28:10 +0100
Subject: [PATCH 05/37] Add Hinge, Huber and LogCosh losses (#199)

---
 docs/src/python/nn.rst        |  4 +-
 docs/src/python/nn/losses.rst |  5 +-
 python/mlx/nn/losses.py       | 93 +++++++++++++++++++++++++++++++++++
 python/tests/test_nn.py       | 18 +++++++
 4 files changed, 117 insertions(+), 3 deletions(-)

diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst
index bc19a8162..4c9868171 100644
--- a/docs/src/python/nn.rst
+++ b/docs/src/python/nn.rst
@@ -123,7 +123,7 @@ To get more detailed information on the arrays in a :class:`Module` you can use
 all the parameters in a :class:`Module` do:
 
 .. code-block:: python
-
+
     from mlx.utils import tree_map
     shapes = tree_map(lambda p: p.shape, mlp.parameters())
 
@@ -131,7 +131,7 @@ As another example, you can count the number of parameters in a :class:`Module`
 with:
 
 .. code-block:: python
-
+
     from mlx.utils import tree_flatten
     num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
 
diff --git a/docs/src/python/nn/losses.rst b/docs/src/python/nn/losses.rst
index b6a202d4a..3fb7589f8 100644
--- a/docs/src/python/nn/losses.rst
+++ b/docs/src/python/nn/losses.rst
@@ -16,4 +16,7 @@ Loss Functions
    mse_loss
    nll_loss
    smooth_l1_loss
-   triplet_loss
\ No newline at end of file
+   triplet_loss
+   hinge_loss
+   huber_loss
+   log_cosh_loss
\ No newline at end of file
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index 755656e4f..35aedf755 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
+import math
+
 import mlx.core as mx
 from mlx.nn.layers.base import Module
 
@@ -283,3 +285,94 @@ def _reduce(loss: mx.array, reduction: str = "none"):
         return loss
     else:
         raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
+
+
+def hinge_loss(
+    inputs: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
+    """
+    Computes the hinge loss between inputs and targets.
+
+    .. math::
+
+       \text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})
+
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values. They should be -1 or 1.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed hinge loss.
+    """
+    loss = mx.maximum(1 - inputs * targets, 0)
+
+    return _reduce(loss, reduction)
+
+
+def huber_loss(
+    inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
+) -> mx.array:
+    """
+    Computes the Huber loss between inputs and targets.
+
+    .. math::
+
+       L_{\delta}(a) =
+       \left\{ \begin{array}{ll}
+           \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
+           \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
+       \end{array} \right.
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values.
+        delta (float, optional): The threshold at which to change between L1 and L2 loss.
+          Default: ``1.0``.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed Huber loss.
+    """
+    errors = inputs - targets
+    abs_errors = mx.abs(errors)
+    quadratic = mx.minimum(abs_errors, delta)
+    linear = abs_errors - quadratic
+    loss = 0.5 * quadratic**2 + delta * linear
+
+    return _reduce(loss, reduction)
+
+
+def log_cosh_loss(
+    inputs: mx.array, targets: mx.array, reduction: str = "none"
+) -> mx.array:
+    """
+    Computes the log cosh loss between inputs and targets.
+
+    Logcosh acts like L2 loss for small errors, ensuring stable gradients,
+    and like the L1 loss for large errors, reducing sensitivity to outliers. This
+    dual behavior offers a balanced, robust approach for regression tasks.
+
+    .. math::
+
+       \text{logcosh}(y_{\text{true}}, y_{\text{pred}}) =
+            \frac{1}{n} \sum_{i=1}^{n}
+            \log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)}))
+
+
+    Args:
+        inputs (array): The predicted values.
+        targets (array): The target values.
+        reduction (str, optional): Specifies the reduction to apply to the output:
+          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+
+    Returns:
+        array: The computed log cosh loss.
+    """
+    errors = inputs - targets
+    loss = mx.logaddexp(errors, -errors) - math.log(2)
+
+    return _reduce(loss, reduction)
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index ebc6f2b7a..0d1c8b2ff 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -581,6 +581,24 @@ class TestNN(mlx_tests.MLXTestCase):
         y = alibi(x.astype(mx.float16))
         self.assertTrue(y.dtype, mx.float16)
 
+    def test_hinge_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.hinge_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 1.0)
+
+    def test_huber_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.huber_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 0.5)
+
+    def test_log_cosh_loss(self):
+        inputs = mx.ones((2, 4))
+        targets = mx.zeros((2, 4))
+        loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
+        self.assertEqual(loss, 0.433781)
+
 
 if __name__ == "__main__":
     unittest.main()
From cd3616a463a594a1679bd9dc823071bed1469c51 Mon Sep 17 00:00:00 2001
From: Ronan Collobert
Date: Fri, 22 Dec 2023 11:01:26 -0800
Subject: [PATCH 06/37] Revisit autorelease memory pools (#260)

* make general autorelease pool part of metal device

* make things simpler

* no metal backend support

* new_memory_pool -> new_scoped_memory_pool
---
 mlx/backend/metal/device.cpp   | 41 ++++++++++++++++++----------------
 mlx/backend/metal/device.h     |  2 --
 mlx/backend/metal/metal.cpp    |  7 +-----
 mlx/backend/metal/metal.h      |  1 +
 mlx/backend/no_metal/metal.cpp |  3 +++
 mlx/scheduler.h                |  1 +
 6 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/mlx/backend/metal/device.cpp b/mlx/backend/metal/device.cpp
index 6b6158f29..c48f2908f 100644
--- a/mlx/backend/metal/device.cpp
+++ b/mlx/backend/metal/device.cpp
@@ -17,10 +17,11 @@ namespace fs = std::filesystem;
 
 namespace mlx::core::metal {
 
-static Device metal_device_;
-
 namespace {
 
+// Catch things related to the main-thread static variables
+static std::shared_ptr<void> global_memory_pool = new_scoped_memory_pool();
+
 // TODO nicer way to set this or possibly expose as an environment variable
 static constexpr int MAX_BUFFERS_PER_QUEUE = 12;
 
@@ -113,26 +114,27 @@ MTL::Library* load_library(
 
 } // namespace
 
-Device::Device()
-    : pool_(NS::AutoreleasePool::alloc()->init()),
-      device_(load_device()),
-      library_map_({{"mlx", load_library(device_)}}) {}
+Device::Device() {
+  auto pool = new_scoped_memory_pool();
+  device_ = load_device();
+  library_map_ = {{"mlx", load_library(device_)}};
+}
 
 Device::~Device() {
   for (auto& q : queue_map_) {
     q.second->release();
   }
-  for (auto& k : kernel_map_) {
-    k.second->release();
-  }
-  for (auto& l : library_map_) {
-    l.second->release();
-  }
   for (auto& b : buffer_map_) {
     b.second.second->release();
   }
   for (auto& e : encoder_map_) {
     e.second->release();
   }
+  for (auto& k : kernel_map_) {
+    k.second->release();
+  }
+  for (auto& l : library_map_) {
+    l.second->release();
+  }
   device_->release();
-  pool_->release();
 }
 
 void Device::new_queue(int index) {
@@ -244,6 +245,7 @@ void Device::register_library(
 MTL::ComputePipelineState* Device::get_kernel(
     const std::string& name,
     const std::string& lib_name /* = "mlx" */) {
+  auto pool = new_scoped_memory_pool();
  // Look for cached kernel
   if (auto it = kernel_map_.find(name); it != kernel_map_.end()) {
     return it->second;
   }
@@ -285,17 +287,18 @@ MTL::ComputePipelineState* Device::get_kernel(
 }
 
 Device& device(mlx::core::Device) {
-  return metal_device_;
+  static Device metal_device;
+  return metal_device;
 }
 
-NS::AutoreleasePool*& thread_autorelease_pool() {
-  static thread_local NS::AutoreleasePool* p =
-      NS::AutoreleasePool::alloc()->init();
-  return p;
+std::shared_ptr<void> new_scoped_memory_pool() {
+  auto dtor = [](void* ptr) {
+    static_cast<NS::AutoreleasePool*>(ptr)->release();
+  };
+  return std::shared_ptr<void>(NS::AutoreleasePool::alloc()->init(), dtor);
 }
 
 void new_stream(Stream stream) {
-  thread_autorelease_pool();
   if (stream.device == mlx::core::Device::gpu) {
     device(stream.device).new_queue(stream.index);
   }
diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h
index 62675d430..45449a332 100644
--- a/mlx/backend/metal/device.h
+++ b/mlx/backend/metal/device.h
@@ -67,7 +67,6 @@ class Device {
       const std::vector<MTL::ArgumentDescriptor*>& arg_descs) const;
 
  private:
-  NS::AutoreleasePool* pool_;
   MTL::Device* device_;
   std::unordered_map<int32_t, MTL::CommandQueue*> queue_map_;
   std::unordered_map<int32_t, std::pair<int, MTL::CommandBuffer*>> buffer_map_;
@@ -78,6 +77,5 @@ class Device {
 };
 
 Device& device(mlx::core::Device);
-NS::AutoreleasePool*& thread_autorelease_pool();
 
 } // namespace mlx::core::metal
diff --git a/mlx/backend/metal/metal.cpp b/mlx/backend/metal/metal.cpp
index f63ad55a3..478e57c73 100644
--- a/mlx/backend/metal/metal.cpp
+++ b/mlx/backend/metal/metal.cpp
@@ -50,6 +50,7 @@ std::function<void()> make_task(
     bool retain_graph) {
   auto task = [retain_graph, arr, deps = std::move(deps), p = std::move(p)]() mutable {
+    auto pool = new_scoped_memory_pool();
     for (auto& d : deps) {
       d.wait();
     }
@@ -66,12 +67,6 @@ std::function<void()> make_task(
           arr.detach();
         }
         p->set_value();
-        // Signal this thread to clear the pool on a synchroniztion.
-        scheduler::enqueue(s, []() {
-          thread_autorelease_pool()->release();
-          thread_autorelease_pool() =
-              NS::AutoreleasePool::alloc()->init();
-        });
         scheduler::notify_task_completion(s);
       });
   metal::device(s.device).commit_command_buffer(s.index);
diff --git a/mlx/backend/metal/metal.h b/mlx/backend/metal/metal.h
index f1f7ede44..99f400956 100644
--- a/mlx/backend/metal/metal.h
+++ b/mlx/backend/metal/metal.h
@@ -20,6 +20,7 @@ constexpr bool is_available() {
 }
 
 void new_stream(Stream stream);
+std::shared_ptr<void> new_scoped_memory_pool();
 
 std::function<void()> make_task(
     array& arr,
diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp
index accfc4c8a..b3a7dc41c 100644
--- a/mlx/backend/no_metal/metal.cpp
+++ b/mlx/backend/no_metal/metal.cpp
@@ -7,6 +7,9 @@ namespace mlx::core::metal {
 
 void new_stream(Stream) {}
+std::shared_ptr<void> new_memory_pool() {
+  return nullptr;
+}
 
 std::function<void()> make_task(
     array& arr,
diff --git a/mlx/scheduler.h b/mlx/scheduler.h
index 6506b20ab..150cc96db 100644
--- a/mlx/scheduler.h
+++ b/mlx/scheduler.h
@@ -35,6 +35,7 @@ struct StreamThread {
   }
 
   void thread_fn() {
+    auto thread_pool = metal::new_scoped_memory_pool();
     metal::new_stream(stream);
     while (true) {
       std::function<void()> task;

From f91f45014199e35a2c4578ced2a5350890ee1368 Mon Sep 17 00:00:00 2001
From: Finn Voorhees
Date: Fri, 22 Dec 2023 23:33:17 -0500
Subject: [PATCH 07/37] Fix argmax returns documentation (#263)

---
 python/src/ops.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 23a6ec2c6..277ef596b 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -2129,7 +2129,7 @@ void init_ops(py::module_& m) {
             singleton dimensions, defaults to `False`.
 
         Returns:
-            array: The output array with the indices of the minimum values.
+            array: The output array with the indices of the maximum values.
       )pbdoc");
   m.def(
       "sort",

From acf1721b982170f4526b1ca4ed6aef95c66a4e95 Mon Sep 17 00:00:00 2001
From: Vidit Agarwal
Date: Sun, 24 Dec 2023 00:36:38 +0530
Subject: [PATCH 08/37] Corrected the example of value_and_grad (#274)

* Corrected the example for mx.value_and_grad

* Reformat through pre-commit/black
---
 python/mlx/nn/layers/dropout.py | 2 --
 python/src/transforms.cpp       | 4 ++--
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py
index e2cc981e2..14c5cb15e 100644
--- a/python/mlx/nn/layers/dropout.py
+++ b/python/mlx/nn/layers/dropout.py
@@ -88,5 +88,3 @@ class Dropout2d(Module):
 
         mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
         return (1 / self._p_1) * mask * x
-
-
diff --git a/python/src/transforms.cpp b/python/src/transforms.cpp
index a592b4458..096d5a486 100644
--- a/python/src/transforms.cpp
+++ b/python/src/transforms.cpp
@@ -569,7 +569,7 @@ void init_transforms(py::module_& m) {
                 return lvalue
 
             # Returns lvalue, dlvalue/dparams
-            lvalue, grads = mx.value_and_grad(mse)
+            lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)
 
             def lasso(params, inputs, targets, a=1.0, b=1.0):
                 outputs = forward(params, inputs)
@@ -580,7 +580,7 @@ void init_transforms(py::module_& m) {
 
                 return loss, mse, l1
 
-            (loss, mse, l1), grads = mx.value_and_grad(lasso)
+            (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
 
         Args:
             fun (function): A function which takes a variable number of

From 8c3da54c7d4e6ffc061ec5985e7c0037d7e72110 Mon Sep 17 00:00:00 2001
From: Vidit Agarwal
Date: Sun, 24 Dec 2023 05:56:46 +0530
Subject: [PATCH 09/37] Fix failing test for log cosh loss (#275)

* fix assert statement in log_cosh_loss

* reformatted by pre-commit black
---
 python/tests/test_nn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 0d1c8b2ff..cc56bc430 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -597,7 +597,7 @@ class TestNN(mlx_tests.MLXTestCase):
         inputs = mx.ones((2, 4))
         targets = mx.zeros((2, 4))
         loss = nn.losses.log_cosh_loss(inputs, targets, reduction="mean")
-        self.assertEqual(loss, 0.433781)
+        self.assertAlmostEqual(loss.item(), 0.433781, places=6)
 
 
 if __name__ == "__main__":

From 8b227fa9afe8f87dd9aca99f82d508e313de65bb Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sat, 23 Dec 2023 19:18:10 -0800
Subject: [PATCH 10/37] fix no metal build (#276)

---
 mlx/backend/no_metal/metal.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp
index b3a7dc41c..212ca2839 100644
--- a/mlx/backend/no_metal/metal.cpp
+++ b/mlx/backend/no_metal/metal.cpp
@@ -7,7 +7,7 @@ namespace mlx::core::metal {
 
 void new_stream(Stream) {}
 
-std::shared_ptr<void> new_memory_pool() {
+std::shared_ptr<void> new_scoped_memory_pool() {
   return nullptr;
 }

From 7365d142a3a4461d383342d899846e9252fdfc81 Mon Sep 17 00:00:00 2001
From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com>
Date: Sun, 24 Dec 2023 16:04:43 +0100
Subject: [PATCH 11/37] random.uniform must respect dtype, even if lower
 precision than "low" (#280)

Fix an edge case where random uniform returns a float32 array, even if a
lower precision dtype is wanted due to adding the float32 "low" array.
---
 mlx/random.cpp              | 6 ++++--
 python/tests/test_random.py | 3 +++
 tests/random_tests.cpp      | 4 ++++
 3 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/mlx/random.cpp b/mlx/random.cpp
index 232c458f9..ef11f8c65 100644
--- a/mlx/random.cpp
+++ b/mlx/random.cpp
@@ -103,7 +103,9 @@ array uniform(
   }
 
   auto stream = to_stream(s);
-  auto range = subtract(high, low, stream);
+  auto lo = astype(low, dtype, stream);
+  auto hi = astype(high, dtype, stream);
+  auto range = subtract(hi, lo, stream);
   auto out_shape = broadcast_shapes(shape, range.shape());
   if (out_shape != shape) {
     std::ostringstream msg;
@@ -136,7 +138,7 @@ array uniform(
   auto out = bits(shape, size_of(dtype), key, stream);
   out = astype(divide(out, maxval, stream), dtype, stream);
   out = minimum(out, upper, stream);
-  return add(multiply(range, out, stream), low, stream);
+  return add(multiply(range, out, stream), lo, stream);
 }
 
 array uniform(
diff --git a/python/tests/test_random.py b/python/tests/test_random.py
index 1603371b3..aa01339f4 100644
--- a/python/tests/test_random.py
+++ b/python/tests/test_random.py
@@ -58,6 +58,9 @@ class TestRandom(mlx_tests.MLXTestCase):
         a = mx.random.uniform(shape=(1000,), low=mx.array(-1), high=5)
         self.assertTrue(mx.all((a > -1) < 5).item())
 
+        a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
+        self.assertEqual(a.dtype, mx.bfloat16)
+
     def test_normal(self):
         key = mx.random.key(0)
         a = mx.random.normal(key=key)
diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp
index 1a387febc..b7793e41c 100644
--- a/tests/random_tests.cpp
+++ b/tests/random_tests.cpp
@@ -260,6 +260,10 @@ TEST_CASE("test random uniform") {
   // Non float type throws
   CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);
 
+  // dtype respected
+  x = random::uniform(-.1, .1, {0}, bfloat16);
+  CHECK_EQ(x.dtype(), bfloat16);
+
   // Check broadcasting
   x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
   CHECK_EQ(x.shape(), std::vector<int>{3, 3});
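
[The fixed edge case, restated as a sketch mirroring the new Python test above. The bounds are now cast to the requested dtype before the final add, so the result keeps the lower precision:]

    import mlx.core as mx

    a = mx.random.uniform(low=-0.1, high=0.1, shape=(1,), dtype=mx.bfloat16)
    print(a.dtype)  # bfloat16; before this fix the add with the float32 bounds promoted it back to float32
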
) - assert ( - predictions.shape == targets.shape - ), f"Shape of predictions {predictions.shape} and targets {targets.shape} must match" - loss = mx.square(predictions - targets) return _reduce(loss, reduction) From 9e6b8c9f48b85d8d345734375aa8fad0fa5eade1 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sun, 24 Dec 2023 14:47:57 -0800 Subject: [PATCH 13/37] Refactor the reduction kernels (#277) --- benchmarks/python/comparative/compare.py | 8 + mlx/backend/common/reduce.h | 2 +- mlx/backend/metal/kernels/reduce.metal | 280 ++++++----------------- mlx/backend/metal/reduce.cpp | 258 +++++++++------------ 4 files changed, 179 insertions(+), 369 deletions(-) diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index c54af3a46..4adde50bc 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -125,6 +125,14 @@ if __name__ == "__main__": compare_filtered("sum_axis --size 16x128x1024 --axis 1") compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu") compare_filtered("sum_axis --size 16x128x1024 --axis 0") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1") compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu") compare_filtered("argmax --size 10x1024x128 --axis 1") compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu") diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h index 740f54a48..da1d1658a 100644 --- a/mlx/backend/common/reduce.h +++ b/mlx/backend/common/reduce.h @@ -126,7 +126,7 @@ struct ReductionPlan { ReductionPlan get_reduction_plan(const array& x, const std::vector axes) { // The data is all there and we are reducing over everything if (x.size() == x.data_size() && axes.size() == x.ndim() && - (x.flags().row_contiguous || x.flags().col_contiguous)) { + x.flags().contiguous) { return ContiguousAllReduce; } diff --git a/mlx/backend/metal/kernels/reduce.metal b/mlx/backend/metal/kernels/reduce.metal index 25bf1ee1f..85ff41f44 100644 --- a/mlx/backend/metal/kernels/reduce.metal +++ b/mlx/backend/metal/kernels/reduce.metal @@ -112,88 +112,33 @@ template uint simd_group_id [[simdgroup_index_in_threadgroup]]); -/////////////////////////////////////////////////////////////////////////////// -// General reduce -/////////////////////////////////////////////////////////////////////////////// - -template -[[kernel]] void general_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const device int *in_shape [[buffer(2)]], - const device size_t *in_strides [[buffer(3)]], - const device size_t *out_strides [[buffer(4)]], - const device size_t& ndim [[buffer(5)]], - uint gid [[thread_position_in_grid]]) { - Op op; - auto in_idx = elem_to_loc(gid, in_shape, in_strides, ndim); - auto out_idx = elem_to_loc(gid, in_shape, out_strides, ndim); - op.atomic_update(out, static_cast(in[in_idx]), out_idx); -} - -template -[[kernel]] void general_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out 
[[buffer(1)]], - const device int *in_shape [[buffer(2)]], - const device size_t *in_strides [[buffer(3)]], - const device size_t *out_strides [[buffer(4)]], - uint gid [[thread_position_in_grid]]) { - Op op; - auto in_idx = elem_to_loc_nd(gid, in_shape, in_strides); - auto out_idx = elem_to_loc_nd(gid, in_shape, out_strides); - op.atomic_update(out, static_cast(in[in_idx]), out_idx); -} - -#define instantiate_general_reduce_helper(name, itype, otype, op) \ - template [[host_name("general_reduce_" #name)]] \ - [[kernel]] void general_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const device int *in_shape [[buffer(2)]], \ - const device size_t *in_strides [[buffer(3)]], \ - const device size_t *out_strides [[buffer(4)]], \ - const device size_t& ndim [[buffer(5)]], \ - uint gid [[thread_position_in_grid]]); - -#define instantiate_general_reduce_helper_nd(name, itype, otype, op, n) \ - template [[host_name("general_reduce_" #name "_dim_" #n)]] \ - [[kernel]] void general_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const device int *in_shape [[buffer(2)]], \ - const device size_t *in_strides [[buffer(3)]], \ - const device size_t *out_strides [[buffer(4)]], \ - uint gid [[thread_position_in_grid]]); - -#define instantiate_general_reduce(name, itype, otype, op) \ - instantiate_general_reduce_helper(name, itype, otype, op) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 1) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 2) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 3) \ - instantiate_general_reduce_helper_nd(name, itype, otype, op, 4) - - /////////////////////////////////////////////////////////////////////////////// // Row atomics /////////////////////////////////////////////////////////////////////////////// template -[[kernel]] void row_reduce( +[[kernel]] void row_reduce_general( const device T *in [[buffer(0)]], - device U *out [[buffer(1)]], - const device size_t& reduction_size [[buffer(2)]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint tid [[threadgroup_position_in_grid]], + device mlx_atomic *out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& out_size [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant size_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_per_group [[simdgroups_per_threadgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { Op op; - // Each threadgroup handles 1 reduction - in += tid * reduction_size + lid * N_READS; + // Each threadgroup handles 1 reduction + // TODO: Specializing elem_to_loc would be slightly faster + int idx = tid.y * out_size + tid.x; + int extra_offset = elem_to_loc(idx, shape, strides, ndim); + in += extra_offset + lid.x * N_READS; // The reduction is accumulated here U total_val = Op::init; @@ -201,7 +146,7 @@ template // Loop over the reduction size within thread group int r = 0; - for (; r < (int)ceildiv(reduction_size, N_READS*lsize) - 1; r++) { + for (; r < (int)ceildiv(reduction_size, N_READS*lsize.x) - 1; r++) { T vals[N_READS]; for(int i = 0; i < N_READS; i++) { vals[i] = in[i]; @@ -210,11 +155,11 @@ template total_val = 
op(static_cast(vals[i]), total_val); } - in += lsize * N_READS; + in += lsize.x * N_READS; } - // Sepate case for the last set as we close the reduction size - size_t reduction_index = (lid + (size_t)lsize * r) * N_READS; + // Separate case for the last set as we close the reduction size + size_t reduction_index = (lid.x + (size_t)lsize.x * r) * N_READS; if(reduction_index < reduction_size) { int max_reads = reduction_size - reduction_index; @@ -240,26 +185,30 @@ template // Reduction within thread group // Only needed if multiple simd groups if(reduction_size > simd_size) { - total_val = lid < simd_per_group ? local_vals[lid] : op.init; + total_val = lid.x < simd_per_group ? local_vals[lid.x] : op.init; total_val = op.simd_reduce(total_val); } // Update output - if (lid == 0) { - out[tid] = total_val; + if (lid.x == 0) { + op.atomic_update(out, total_val, tid.x); } } -#define instantiate_row_reduce(name, itype, otype, op) \ - template [[host_name("row_reduce_" #name)]] \ - [[kernel]] void row_reduce( \ - const device itype *in [[buffer(0)]], \ - device otype *out [[buffer(1)]], \ - const device size_t& reduction_size [[buffer(2)]], \ - uint lid [[thread_position_in_threadgroup]], \ - uint lsize [[threads_per_threadgroup]], \ - uint tid [[threadgroup_position_in_grid]], \ - uint simd_lane_id [[thread_index_in_simdgroup]], \ - uint simd_per_group [[simdgroups_per_threadgroup]], \ +#define instantiate_row_reduce_general(name, itype, otype, op) \ + template [[host_name("row_reduce_general_" #name)]] \ + [[kernel]] void row_reduce_general( \ + const device itype *in [[buffer(0)]], \ + device mlx_atomic *out [[buffer(1)]], \ + const constant size_t& reduction_size [[buffer(2)]], \ + const constant size_t& out_size [[buffer(3)]], \ + const constant int* shape [[buffer(4)]], \ + const constant size_t* strides [[buffer(5)]], \ + const constant int& ndim [[buffer(6)]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_per_group [[simdgroups_per_threadgroup]], \ uint simd_group_id [[simdgroup_index_in_threadgroup]]); @@ -311,148 +260,57 @@ inline void _contiguous_strided_reduce( } template -[[kernel]] void col_reduce( +[[kernel]] void col_reduce_general( const device T *in [[buffer(0)]], device mlx_atomic *out [[buffer(1)]], const constant size_t& reduction_size [[buffer(2)]], const constant size_t& reduction_stride [[buffer(3)]], const constant size_t& out_size [[buffer(4)]], + const constant int* shape [[buffer(5)]], + const constant size_t* strides [[buffer(6)]], + const constant int& ndim [[buffer(7)]], threadgroup U *local_data [[threadgroup(0)]], - uint2 tid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]]) { - auto out_idx = tid.x * lsize.x + lid.x; - + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + auto out_idx = tid.x * lsize.x + lid.x; + auto in_idx = elem_to_loc( + out_idx + tid.z * out_size, + shape, + strides, + ndim + ); + if(out_idx < out_size) { _contiguous_strided_reduce( in, out, local_data, - out_idx, + in_idx, out_idx, reduction_size, reduction_stride, - tid, - lid, - lsize); + tid.xy, + lid.xy, + lsize.xy); } } -#define instantiate_col_reduce(name, itype, otype, op) \ - template [[host_name("col_reduce_" #name)]] \ - [[kernel]] void col_reduce( \ +#define 
instantiate_col_reduce_general(name, itype, otype, op) \ + template [[host_name("col_reduce_general_" #name)]] \ + [[kernel]] void col_reduce_general( \ const device itype *in [[buffer(0)]], \ device mlx_atomic *out [[buffer(1)]], \ const constant size_t& reduction_size [[buffer(2)]], \ const constant size_t& reduction_stride [[buffer(3)]], \ const constant size_t& out_size [[buffer(4)]], \ + const constant int* shape [[buffer(5)]], \ + const constant size_t* strides [[buffer(6)]], \ + const constant int& ndim [[buffer(7)]], \ threadgroup otype *local_data [[threadgroup(0)]], \ - uint2 tid [[threadgroup_position_in_grid]], \ - uint2 lid [[thread_position_in_threadgroup]], \ - uint2 lsize [[threads_per_threadgroup]]); - -template -[[kernel]] void contiguous_strided_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const device int* in_shape [[buffer(5)]], - const device size_t* in_strides [[buffer(6)]], - threadgroup U *local_data [[threadgroup(0)]], - uint2 tid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]]) { - - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc_nd(out_idx, in_shape, in_strides); - - if(out_idx < out_size) { - _contiguous_strided_reduce( - in, - out, - local_data, - in_idx, - out_idx, - reduction_size, - reduction_stride, - tid, - lid, - lsize); - } -} - -template -[[kernel]] void contiguous_strided_reduce( - const device T *in [[buffer(0)]], - device mlx_atomic *out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant size_t& out_size [[buffer(4)]], - const device int* in_shape [[buffer(5)]], - const device size_t* in_strides [[buffer(6)]], - const device size_t& in_dim [[buffer(7)]], - threadgroup U *local_data [[threadgroup(0)]], - uint2 tid [[threadgroup_position_in_grid]], - uint2 lid [[thread_position_in_threadgroup]], - uint2 lsize [[threads_per_threadgroup]]) { - - auto out_idx = tid.x * lsize.x + lid.x; - auto in_idx = elem_to_loc(out_idx, in_shape, in_strides, in_dim); - - if(out_idx < out_size) { - _contiguous_strided_reduce( - in, - out, - local_data, - in_idx, - out_idx, - reduction_size, - reduction_stride, - tid, - lid, - lsize); - } -} - -#define instantiate_contiguous_strided_helper(name, itype, otype, op) \ - template [[host_name("contiguous_strided_reduce_" #name)]] \ - [[kernel]] void contiguous_strided_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const device int* in_shape [[buffer(5)]], \ - const device size_t* in_strides [[buffer(6)]], \ - const device size_t& in_dim [[buffer(7)]], \ - threadgroup otype *local_data [[threadgroup(0)]], \ - uint2 tid [[threadgroup_position_in_grid]], \ - uint2 lid [[thread_position_in_threadgroup]], \ - uint2 lsize [[threads_per_threadgroup]]); - -#define instantiate_contiguous_strided_helper_nd(name, itype, otype, op, n) \ - template [[host_name("contiguous_strided_reduce_" #name "_dim_" #n)]] \ - [[kernel]] void contiguous_strided_reduce( \ - const device itype *in [[buffer(0)]], \ - device mlx_atomic *out [[buffer(1)]], \ - 
const constant size_t& reduction_size [[buffer(2)]], \ - const constant size_t& reduction_stride [[buffer(3)]], \ - const constant size_t& out_size [[buffer(4)]], \ - const device int* in_shape [[buffer(5)]], \ - const device size_t* in_strides [[buffer(6)]], \ - threadgroup otype *local_data [[threadgroup(0)]], \ - uint2 tid [[threadgroup_position_in_grid]], \ - uint2 lid [[thread_position_in_threadgroup]], \ - uint2 lsize [[threads_per_threadgroup]]); - -#define instantiate_contiguous_strided(name, itype, otype, op) \ - instantiate_contiguous_strided_helper(name, itype, otype, op) \ - instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 1) \ - instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 2) \ - instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 3) \ - instantiate_contiguous_strided_helper_nd(name, itype, otype, op, 4) + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 lsize [[threads_per_threadgroup]]); /////////////////////////////////////////////////////////////////////////////// @@ -461,10 +319,8 @@ template #define instantiate_reduce(name, itype, otype, op) \ instantiate_all_reduce(name, itype, otype, op) \ - instantiate_row_reduce(name, itype, otype, op) \ - instantiate_col_reduce(name, itype, otype, op) \ - instantiate_contiguous_strided(name, itype, otype, op) \ - instantiate_general_reduce(name, itype, otype, op) + instantiate_row_reduce_general(name, itype, otype, op) \ + instantiate_col_reduce_general(name, itype, otype, op) #define instantiate_same_reduce(name, tname, type, op) \ instantiate_init_reduce(name ##tname, type, op) \ @@ -535,4 +391,4 @@ instantiate_same_reduce(max_, float16, half, Max) instantiate_same_reduce(max_, float32, float, Max) instantiate_same_reduce(min_, bfloat16, bfloat16_t, Min) -instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) \ No newline at end of file +instantiate_same_reduce(max_, bfloat16, bfloat16_t, Max) diff --git a/mlx/backend/metal/reduce.cpp b/mlx/backend/metal/reduce.cpp index 532f18353..6a2ce084b 100644 --- a/mlx/backend/metal/reduce.cpp +++ b/mlx/backend/metal/reduce.cpp @@ -2,9 +2,11 @@ #include #include +#include #include #include "mlx/backend/common/reduce.h" +#include "mlx/backend/metal/copy.h" #include "mlx/backend/metal/device.h" #include "mlx/backend/metal/kernels/defines.h" #include "mlx/backend/metal/utils.h" @@ -61,22 +63,47 @@ void all_reduce_dispatch( compute_encoder->dispatchThreads(grid_dims, group_dims); } -void row_reduce_dispatch( +void row_reduce_general_dispatch( const array& in, array& out, const std::string& op_name, - const std::vector& axes_, + const ReductionPlan& plan, + const std::vector& axes, MTL::ComputeCommandEncoder* compute_encoder, metal::Device& d) { - auto kernel = d.get_kernel("row_reduce_" + op_name + type_to_name(in)); + auto kernel = + d.get_kernel("row_reduce_general_" + op_name + type_to_name(in)); + // Prepare the arguments for the kernel int n_reads = REDUCE_N_READS; - size_t reduction_size = in.size() / out.size(); + size_t reduction_size = plan.shape.back(); + size_t out_size = out.size(); + auto shape = plan.shape; + auto strides = plan.strides; + shape.pop_back(); + strides.pop_back(); + size_t non_row_reductions = 1; + for (auto s : shape) { + non_row_reductions *= static_cast(s); + } + auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); + for (auto s : rem_shape) { + shape.push_back(s); + } + for (auto s : rem_strides) { + strides.push_back(s); + } + int ndim = 
shape.size(); + // Set the arguments for the kernel compute_encoder->setComputePipelineState(kernel); set_array_buffer(compute_encoder, in, 0); set_array_buffer(compute_encoder, out, 1); compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); + compute_encoder->setBytes(&out_size, sizeof(size_t), 3); + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 4); + compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 5); + compute_encoder->setBytes(&ndim, sizeof(int), 6); // Each thread group is responsible for 1 output NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); @@ -91,92 +118,54 @@ void row_reduce_dispatch( // Launch enough thread groups for each output size_t n_threads = out.size() * thread_group_size; - MTL::Size grid_dims = MTL::Size(n_threads, 1, 1); + MTL::Size grid_dims = MTL::Size(n_threads, non_row_reductions, 1); MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); compute_encoder->dispatchThreads(grid_dims, group_dims); } -void col_reduce_dispatch( +void strided_reduce_general_dispatch( const array& in, array& out, const std::string& op_name, - const std::vector& axes_, + const ReductionPlan& plan, + const std::vector& axes, MTL::ComputeCommandEncoder* compute_encoder, metal::Device& d) { - std::ostringstream kernel_name; + auto kernel = + d.get_kernel("col_reduce_general_" + op_name + type_to_name(in)); - bool encode_in_shape = false; - bool encode_ndim = false; - - // If the slowest moving axis can be merged into the reductions, - // we call the column reduce kernel - // In this case, a linear index in the output corresponds to the - // linear index in the input where the reduction starts - if (axes_[axes_.size() - 1] == (axes_.size() - 1)) { - kernel_name << "col_reduce_" << op_name << type_to_name(in); - } - // Otherwise, while all the reduction axes can be merged, the mapping between - // indices in the output and input require resolving using shapes and strides - else { - kernel_name << "contiguous_strided_reduce_" << op_name << type_to_name(in); - encode_in_shape = true; - - // We check for a viable template with the required number of dimensions - // we only care about encoding non-reduced shapes and strides in the input - size_t non_reducing_dims = in.ndim() - axes_.size(); - if (non_reducing_dims >= 1 && - non_reducing_dims <= MAX_REDUCE_SPECIALIZED_DIMS) { - kernel_name << "_dim_" << non_reducing_dims; - } else { - encode_ndim = true; - } - } - - auto kernel = d.get_kernel(kernel_name.str()); - size_t in_size = in.size(); + // Prepare the arguments for the kernel + size_t reduction_size = plan.shape.back(); + size_t reduction_stride = plan.strides.back(); size_t out_size = out.size(); + auto shape = plan.shape; + auto strides = plan.strides; + shape.pop_back(); + strides.pop_back(); + size_t non_col_reductions = 1; + for (auto s : shape) { + non_col_reductions *= static_cast(s); + } + auto [rem_shape, rem_strides] = shapes_without_reduction_axes(in, axes); + for (auto s : rem_shape) { + shape.push_back(s); + } + for (auto s : rem_strides) { + strides.push_back(s); + } + int ndim = shape.size(); + // Set the arguments for the kernel compute_encoder->setComputePipelineState(kernel); set_array_buffer(compute_encoder, in, 0); set_array_buffer(compute_encoder, out, 1); - - // Calculate the number of inputs to reduce and the stride b/w them - size_t reduction_size = 1; - size_t in_ndim = in.ndim(); - size_t reduction_stride = in_size; - - for (int i : axes_) { - reduction_size *= in.shape(i); - 
reduction_stride = std::min(reduction_stride, in.strides()[i]); - } - compute_encoder->setBytes(&reduction_size, sizeof(size_t), 2); compute_encoder->setBytes(&reduction_stride, sizeof(size_t), 3); compute_encoder->setBytes(&out_size, sizeof(size_t), 4); - if (encode_in_shape) { - // Obtain the non-reducing shape and strides of the input to encode - std::vector inp_shape_mod; - std::vector inp_strides_mod; - - for (size_t i = 0, j = 0; i < in.ndim(); i++) { - if (j < axes_.size() && axes_[j] == i) { - j++; - } else { - inp_shape_mod.push_back(in.shape(i)); - inp_strides_mod.push_back(in.strides()[i]); - } - } - - size_t ndim = inp_shape_mod.size(); - - compute_encoder->setBytes(inp_shape_mod.data(), ndim * sizeof(int), 5); - compute_encoder->setBytes(inp_strides_mod.data(), ndim * sizeof(size_t), 6); - - if (encode_ndim) { - compute_encoder->setBytes(&ndim, sizeof(size_t), 7); - } - } + compute_encoder->setBytes(shape.data(), shape.size() * sizeof(int), 5); + compute_encoder->setBytes(strides.data(), strides.size() * sizeof(size_t), 6); + compute_encoder->setBytes(&ndim, sizeof(int), 7); // Select block dimensions @@ -200,7 +189,8 @@ void col_reduce_dispatch( (n_threads_per_output + threadgroup_dim_y - 1) / threadgroup_dim_y; // Launch enough thread groups for each output - MTL::Size grid_dims = MTL::Size(n_threadgroups_x, n_threadgroups_y, 1); + MTL::Size grid_dims = + MTL::Size(n_threadgroups_x, n_threadgroups_y, non_col_reductions); MTL::Size group_dims = MTL::Size(threadgroup_dim_x, threadgroup_dim_y, 1); // We set shared memory to be exploited here for reductions within a @@ -216,60 +206,6 @@ void col_reduce_dispatch( compute_encoder->dispatchThreadgroups(grid_dims, group_dims); } -void general_reduce_dispatch( - const array& in, - array& out, - const std::string& op_name, - const std::vector& axes_, - MTL::ComputeCommandEncoder* compute_encoder, - metal::Device& d) { - bool encode_ndim = true; - std::ostringstream kernel_name; - kernel_name << "general_reduce_" << op_name << type_to_name(in); - - // Check for specialzed kernels for input ndim - if (in.ndim() >= 1 && in.ndim() <= MAX_REDUCE_SPECIALIZED_DIMS) { - kernel_name << "_dim_" << in.ndim(); - encode_ndim = false; - } - auto kernel = d.get_kernel(kernel_name.str()); - size_t in_size = in.size(); - size_t ndim = in.ndim(); - - // We set the reducing strides to 0 to induce collisions for the reduction - std::vector out_strides(ndim); - size_t stride = 1; - for (int i = ndim - 1, j = axes_.size() - 1; i >= 0; --i) { - if (j >= 0 && axes_[j] == i) { - out_strides[i] = 0; - --j; - } else { - out_strides[i] = stride; - stride *= in.shape(i); - } - } - - compute_encoder->setComputePipelineState(kernel); - set_array_buffer(compute_encoder, in, 0); - set_array_buffer(compute_encoder, out, 1); - compute_encoder->setBytes(in.shape().data(), ndim * sizeof(int), 2); - compute_encoder->setBytes(in.strides().data(), ndim * sizeof(size_t), 3); - compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4); - if (encode_ndim) { - compute_encoder->setBytes(&ndim, sizeof(size_t), 5); - } - - NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); - if (thread_group_size > in_size) { - thread_group_size = in_size; - } - size_t nthreads = in_size; - - MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); - MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); - compute_encoder->dispatchThreads(grid_dims, group_dims); -} - } // namespace ////////////////////////////////////////////////////////////////////// @@ -278,7 
+214,7 @@ void general_reduce_dispatch( void Reduce::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); - auto& in = inputs[0]; + array in = inputs[0]; // TODO: Allow specific row and column reductions with types disabled // due to atomics ? @@ -335,36 +271,46 @@ void Reduce::eval_gpu(const std::vector& inputs, array& out) { // Reduce { - // Check for contiguous data - if (in.size() == in.data_size() && - (in.flags().row_contiguous || in.flags().col_contiguous)) { - // Go to all reduce if reducing over all axes - if (axes_.size() == in.ndim()) { - all_reduce_dispatch(in, out, op_name, compute_encoder, d); - return; - } - // Use specialized kernels if the input is row contiguous and - // the reducing axes can be merged into one - else if ( - in.flags().row_contiguous && in.strides().back() == 1 && - (axes_.back() - axes_.front()) == axes_.size() - 1) { - // If the fastest moving axis is being reduced, go to row reduce - if (axes_[0] == (in.ndim() - axes_.size())) { - row_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d); - return; - } - // Otherwise go to to generalized strided reduce - // Note: bool isn't support here yet due to the use of atomics - // once that is updated, this should be the else condition of this - // branch - else if (in.dtype() != bool_) { - col_reduce_dispatch(in, out, op_name, axes_, compute_encoder, d); - return; - } - } + std::vector copies; + ReductionPlan plan = get_reduction_plan(in, axes_); + + // If it is a general reduce then copy the input to a contiguous array and + // recompute the plan. + if (plan.type == GeneralReduce) { + array in_copy(in.shape(), in.dtype(), nullptr, {}); + copy_gpu(in, in_copy, CopyType::General, s); + copies.push_back(in_copy); + in = in_copy; + plan = get_reduction_plan(in, axes_); + } + + // Reducing over everything and the data is all there no broadcasting or + // slicing etc. + if (plan.type == ContiguousAllReduce) { + all_reduce_dispatch(in, out, op_name, compute_encoder, d); + } + + // At least the last dimension is row contiguous and we are reducing over + // the last dim. + else if ( + plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) { + row_reduce_general_dispatch( + in, out, op_name, plan, axes_, compute_encoder, d); + } + + // At least the last two dimensions are contiguous and we are doing a + // strided reduce over these. 
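
[For orientation, the reduction patterns the new benchmark cases exercise correspond to plain multi-axis sums over (optionally permuted) arrays; a small sketch using the public ops, which dispatch to these kernels on the GPU backend:]

    import mlx.core as mx

    x = mx.random.normal((16, 128, 1024))
    mx.sum(x, axis=(0, 1))                           # row-contiguous reduction
    mx.sum(mx.transpose(x, (0, 2, 1)), axis=(0, 2))  # strided reduction over permuted axes
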
From a123c3c7d2da4332e288eb2036c707bc38a18048 Mon Sep 17 00:00:00 2001
From: __mo_san__ <50895527+m0saan@users.noreply.github.com>
Date: Mon, 25 Dec 2023 16:32:53 +0100
Subject: [PATCH 14/37] implement-batch-norm-layer (#217)

- Add batch normalization layer

---------

Co-authored-by: Robert McCraith
Co-authored-by: Awni Hannun
---
 docs/src/python/nn/layers.rst         |   1 +
 python/mlx/nn/layers/__init__.py      |   2 +-
 python/mlx/nn/layers/dropout.py       |  12 +--
 python/mlx/nn/layers/normalization.py | 120 ++++++++++++++++++++
 python/mlx/nn/losses.py               |   6 +-
 python/tests/test_nn.py               | 137 ++++++++++++++++++++++++++
 6 files changed, 267 insertions(+), 11 deletions(-)

diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index 4c7d4aa79..5ef45d60d 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -20,6 +20,7 @@ Layers
    Linear
    Conv1d
    Conv2d
+   BatchNorm
    LayerNorm
    RMSNorm
    GroupNorm
diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index d54e45f6d..5ac82356a 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -36,7 +36,7 @@ from mlx.nn.layers.convolution import Conv1d, Conv2d
 from mlx.nn.layers.dropout import Dropout, Dropout2d
 from mlx.nn.layers.embedding import Embedding
 from mlx.nn.layers.linear import Linear
-from mlx.nn.layers.normalization import GroupNorm, LayerNorm, RMSNorm
+from mlx.nn.layers.normalization import BatchNorm, GroupNorm, LayerNorm, RMSNorm
 from mlx.nn.layers.positional_encoding import ALiBi, RoPE, SinusoidalPositionalEncoding
 from mlx.nn.layers.quantized import QuantizedLinear
 from mlx.nn.layers.transformer import (
diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py
index 14c5cb15e..caa7a6452 100644
--- a/python/mlx/nn/layers/dropout.py
+++ b/python/mlx/nn/layers/dropout.py
@@ -5,7 +5,7 @@ from mlx.nn.layers.base import Module
 
 
 class Dropout(Module):
-    """Randomly zero a portion of the elements during training.
+    r"""Randomly zero a portion of the elements during training.
 
     The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
     :math:`p` is the probability of zeroing an element. This is done so the
@@ -36,15 +36,13 @@ class Dropout(Module):
 
 
 class Dropout2d(Module):
-    """Apply 2D channel-wise dropout during training.
+    r"""Apply 2D channel-wise dropout during training.
 
     Randomly zero out entire channels independently with probability :math:`p`.
     This layer expects the channels to be last, i.e. the input shape should be
-    ``NWHC`` or ``WHC`` where:
-        - ``N`` is the batch dimension
-        - ``H`` is the input image height
-        - ``W`` is the input image width
-        - ``C`` is the number of input channels
+    ``NWHC`` or ``WHC`` where:``N`` is the batch dimension,``H`` is the input
+    image height,``W`` is the input image width, and``C`` is the number of
+    input channels
 
     The remaining channels are scaled by :math:`\frac{1}{1-p}` to
     maintain the expected value of each element. Unlike traditional dropout,
diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index 6de377cda..9cd578fb2 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -1,5 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
+from typing import Tuple
+
 import mlx.core as mx
 from mlx.nn.layers.base import Module
 
@@ -178,3 +180,121 @@ class GroupNorm(Module):
         )
         x = group_norm(x)
         return (self.weight * x + self.bias) if "weight" in self else x
+
+
+class BatchNorm(Module):
+    r"""Applies Batch Normalization over a 2D or 3D input.
+
+    Computes
+
+    .. math::
+
+        y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,
+
+    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
+    parameters initialized at 1 and 0 respectively.
+
+    The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
+    batch, ``C`` is the number of features or channels, and ``L`` is the
+    sequence length. The output has the same shape as the input. For
+    four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
+    the height and width respecitvely.
+
+    For more information on Batch Normalization, see the original paper `Batch
+    Normalization: Accelerating Deep Network Training by Reducing Internal
+    Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
+
+    Args:
+        num_features (int): The feature dimension to normalize over.
+        eps (float, optional): A small additive constant for numerical
+            stability. Default: ``1e-5``.
+        momentum (float, optional): The momentum for updating the running
+            mean and variance. Default: ``0.1``.
+        affine (bool, optional): If ``True``, apply a learned affine
+            transformation after the normalization. Default: ``True``.
+        track_running_stats (bool, optional): If ``True``, track the
+            running mean and variance. Default: ``True``.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+        >>> x = mx.random.normal((5, 4))
+        >>> bn = nn.BatchNorm(num_features=4, affine=True)
+        >>> output = bn(x)
+    """
+
+    def __init__(
+        self,
+        num_features: int,
+        eps: float = 1e-5,
+        momentum: float = 0.1,
+        affine: bool = True,
+        track_running_stats: bool = True,
+    ):
+        super().__init__()
+
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        self.track_running_stats = track_running_stats
+
+        if affine:
+            self.weight = mx.ones((num_features,))
+            self.bias = mx.zeros((num_features,))
+
+        if self.track_running_stats:
+            self._running_mean = mx.zeros((num_features,))
+            self._running_var = mx.ones((num_features,))
+
+    def _extra_repr(self):
+        return (
+            f"{self.num_features}, eps={self.eps}, "
+            f"momentum={self.momentum}, affine={'weight' in self}, "
+            f"track_running_stats={self.track_running_stats}"
+        )
+
+    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
+        """
+        Calculate the mean and variance of the input tensor.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            tuple: Tuple containing mean and variance.
+        """
+        reduction_axes = tuple(range(0, x.ndim - 1))
+        means = mx.mean(x, axis=reduction_axes, keepdims=True)
+        var = mx.var(x, axis=reduction_axes, keepdims=True)
+
+        if self.track_running_stats and self.training:
+            self._running_mean = (
+                1 - self.momentum
+            ) * self._running_mean + self.momentum * means
+            self._running_var = (
+                1 - self.momentum
+            ) * self._running_var + self.momentum * var
+        return means, var
+
+    def __call__(self, x: mx.array) -> mx.array:
+        """
+        Forward pass of BatchNorm.
+
+        Args:
+            x (mx.array): Input tensor.
+
+        Returns:
+            mx.array: Output tensor.
+        """
+
+        if x.ndim < 2 or x.ndim > 4:
+            raise ValueError(
+                f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
+            )
+
+        if self.training or not self.track_running_stats:
+            means, var = self._calc_stats(x)
+        else:
+            means, var = self._running_mean, self._running_var
+        x = (x - means) * mx.rsqrt(var + self.eps)
+        return (self.weight * x + self.bias) if "weight" in self else x
diff --git a/python/mlx/nn/losses.py b/python/mlx/nn/losses.py
index cfb6ffa15..91316fd04 100644
--- a/python/mlx/nn/losses.py
+++ b/python/mlx/nn/losses.py
@@ -286,7 +286,7 @@ def _reduce(loss: mx.array, reduction: str = "none"):
 def hinge_loss(
     inputs: mx.array, targets: mx.array, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the hinge loss between inputs and targets.
 
     .. math::
@@ -311,7 +311,7 @@ def hinge_loss(
 def huber_loss(
     inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the Huber loss between inputs and targets.
 
     .. math::
@@ -345,7 +345,7 @@ def huber_loss(
 def log_cosh_loss(
     inputs: mx.array, targets: mx.array, reduction: str = "none"
 ) -> mx.array:
-    """
+    r"""
     Computes the log cosh loss between inputs and targets.
 
     Logcosh acts like L2 loss for small errors, ensuring stable gradients,
diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index cc56bc430..2cfac4475 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -320,6 +320,143 @@ class TestNN(mlx_tests.MLXTestCase):
         self.assertTrue(np.allclose(means, 3 * np.ones_like(means), atol=1e-6))
         self.assertTrue(np.allclose(var, 4 * np.ones_like(var), atol=1e-6))
 
+    def test_batch_norm(self):
+        mx.random.seed(42)
+        x = mx.random.normal((5, 4), dtype=mx.float32)
+
+        # Batch norm
+        bn = nn.BatchNorm(num_features=4, affine=True)
+        self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean)))
+        self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var)))
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ],
+        )
+        expected_mean = mx.array([0.008929, 0.005680, -0.016092, 0.027778])
+        expected_var = mx.array([0.928435, 1.00455, 1.04117, 0.94258])
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5))
+        self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5))
+
+        # test eval mode
+        bn.eval()
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.15984, 1.73159, -1.25456, 1.57891],
+                [-0.872193, -1.4281, -0.414439, -0.228678],
+                [0.602743, -0.30566, -0.554687, 0.139639],
+                [0.252199, 0.29066, -0.599572, -0.0512532],
+                [0.594096, -0.0334829, 2.11359, -0.151081],
+            ]
+        )
+
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+
+        # test_no_affine
+        bn = nn.BatchNorm(num_features=4, affine=False)
+        y = bn(x)
+        expected_y = mx.array(
+            [
+                [-0.439520, 1.647328, -0.955515, 1.966031],
+                [-1.726690, -1.449826, -0.234026, -0.723364],
+                [0.938414, -0.349603, -0.354470, -0.175369],
+                [0.305006, 0.234914, -0.393017, -0.459385],
+                [0.922789, -0.082813, 1.937028, -0.607913],
+            ]
+        )
+        self.assertTrue(x.shape == y.shape)
+        self.assertTrue(mx.allclose(y, expected_y, atol=1e-5))
+
+        # test with 3D input
+        mx.random.seed(42)
+        N = 2
+        L = 4
+        C = 5
+ x = mx.random.normal((N, L, C), dtype=mx.float32) + + # Batch norm + bn = nn.BatchNorm(num_features=C, affine=True) + self.assertTrue(mx.allclose(bn._running_mean, mx.zeros_like(bn._running_mean))) + self.assertTrue(mx.allclose(bn._running_var, mx.ones_like(bn._running_var))) + y = bn(x) + self.assertTrue(x.shape == y.shape) + expected_y = mx.array( + [ + [ + [-0.335754, 0.342054, 1.02653, 0.628588, -1.63899], + [1.92092, 0.432319, 0.343043, 1.95489, 1.0696], + [-0.853748, 1.3661, 0.868569, 0.0199196, -0.887284], + [0.459206, -0.684822, -0.706354, -0.271531, 0.566341], + ], + [ + [-0.921179, 0.684951, -0.77466, -0.490372, -0.247032], + [1.10839, -2.13179, 0.628924, -1.62639, -0.539708], + [-0.348943, 0.412194, -2.03818, 0.524972, 1.64568], + [-1.02889, -0.421, 0.652127, -0.740079, 0.0313996], + ], + ] + ) + self.assertTrue(mx.allclose(y, expected_y, atol=1e-5)) + expected_mean = mx.array( + [[[0.00207845, -5.3259e-05, 0.04755, -0.0697296, 0.0236228]]] + ) + expected_var = mx.array([[[0.968415, 1.05322, 0.96913, 0.932305, 0.967224]]]) + self.assertTrue(mx.allclose(bn._running_mean, expected_mean, atol=1e-5)) + self.assertTrue(mx.allclose(bn._running_var, expected_var, atol=1e-5)) + + x = mx.random.normal((N, L, C, L, C), dtype=mx.float32) + with self.assertRaises(ValueError): + y = bn(x) + + def test_batch_norm_stats(self): + batch_size = 2 + num_features = 4 + h = 3 + w = 3 + momentum = 0.1 + + batch_norm = nn.BatchNorm(num_features) + + batch_norm.train() + running_mean = np.array(batch_norm._running_mean) + running_var = np.array(batch_norm._running_var) + + data = mx.random.normal((batch_size, num_features)) + + normalized_data = batch_norm(data) + np_data = np.array(data) + means = np.mean(np_data, axis=0) + variances = np.var(np_data, axis=0) + running_mean = (1 - momentum) * running_mean + momentum * means + running_var = (1 - momentum) * running_var + momentum * variances + self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)) + self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5)) + + batch_norm = nn.BatchNorm(num_features) + + batch_norm.train() + running_mean = np.array(batch_norm._running_mean) + running_var = np.array(batch_norm._running_var) + data = mx.random.normal((batch_size, h, w, num_features)) + + normalized_data = batch_norm(data) + np_data = np.array(data) + means = np.mean(np_data, axis=(0, 1, 2)) + variances = np.var(np_data, axis=(0, 1, 2)) + running_mean = (1 - momentum) * running_mean + momentum * means + running_var = (1 - momentum) * running_var + momentum * variances + self.assertTrue(np.allclose(batch_norm._running_mean, running_mean, atol=1e-5)) + self.assertTrue(np.allclose(batch_norm._running_var, running_var, atol=1e-5)) + def test_conv1d(self): N = 5 L = 12 From d58ac083f3f6903a6115aefb3bb99a8c37cbb9f1 Mon Sep 17 00:00:00 2001 From: Daniel Strobusch <1847260+dastrobu@users.noreply.github.com> Date: Mon, 25 Dec 2023 19:34:28 +0100 Subject: [PATCH 15/37] expose itemsize and nbytes as for numpy arrays (#284) see: * https://numpy.org/doc/stable/reference/generated/numpy.ndarray.nbytes.html * https://numpy.org/doc/stable/reference/generated/numpy.ndarray.itemsize.html relates to https://github.com/ml-explore/mlx-examples/pull/174 --- python/src/array.cpp | 8 ++++++++ python/tests/test_array.py | 2 ++ 2 files changed, 10 insertions(+) diff --git a/python/src/array.cpp b/python/src/array.cpp index 74223cdda..2580de0ea 100644 --- a/python/src/array.cpp +++ b/python/src/array.cpp @@ -510,6 +510,14 @@ void 
init_array(py::module_& m) {
           "size", &array::size, R"pbdoc(Number of elements in the array.)pbdoc")
       .def_property_readonly(
           "ndim", &array::ndim, R"pbdoc(The array's dimension.)pbdoc")
+      .def_property_readonly(
+          "itemsize",
+          &array::itemsize,
+          R"pbdoc(The size of the array's datatype in bytes.)pbdoc")
+      .def_property_readonly(
+          "nbytes",
+          &array::nbytes,
+          R"pbdoc(The number of bytes in the array.)pbdoc")
       // TODO, this makes a deep copy of the shape
       // implement alternatives to use reference
       // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
diff --git a/python/tests/test_array.py b/python/tests/test_array.py
index fb6a24cbc..847d6d142 100644
--- a/python/tests/test_array.py
+++ b/python/tests/test_array.py
@@ -84,6 +84,8 @@ class TestArray(mlx_tests.MLXTestCase):
         x = mx.array(1)
         self.assertEqual(x.size, 1)
         self.assertEqual(x.ndim, 0)
+        self.assertEqual(x.itemsize, 4)
+        self.assertEqual(x.nbytes, 4)
         self.assertEqual(x.shape, [])
         self.assertEqual(x.dtype, mx.int32)
         self.assertEqual(x.item(), 1)

From fc4e5b476b4e3801d1fb0130404fdea3432bcaf2 Mon Sep 17 00:00:00 2001
From: Yutaka Kondo
Date: Mon, 25 Dec 2023 20:53:20 -0800
Subject: [PATCH 16/37] Fix llama link in README.md (#289)

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 72022b0a1..0276e5006 100644
--- a/README.md
+++ b/README.md
@@ -53,7 +53,7 @@ variety of examples, including:
 - [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
 - Large-scale text generation with
-  [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llama) and
+  [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
   finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
 - Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
 - Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
From 447bc089b95b8d123daaee5742b5f43c400220d1 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Tue, 26 Dec 2023 19:21:05 -0800 Subject: [PATCH 17/37] Fix tolerance in de-/quantization test (#295) --- python/tests/test_quantized.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 5fcc882a5..049f92fdb 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -13,7 +13,8 @@ class TestQuantized(mlx_tests.MLXTestCase): w_q, scales, biases = mx.quantize(w, 64, b) w_hat = mx.dequantize(w_q, scales, biases, 64, b) errors = (w - w_hat).abs().reshape(*scales.shape, -1) - self.assertTrue((errors <= scales[..., None] / 2).all()) + eps = 1e-6 + self.assertTrue((errors <= (scales[..., None] / 2 + eps)).all()) def test_qmm(self): key = mx.random.key(0) From cc9b2dc3c2f1bfb05dd5e2230aba9faf4ece7d47 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Sun, 17 Dec 2023 02:55:33 +0100 Subject: [PATCH 18/37] implemented vector_norm in cpp added linalg to mlx --- mlx/CMakeLists.txt | 1 + mlx/linalg.cpp | 64 ++++++++++++++ mlx/linalg.h | 45 ++++++++++ mlx/mlx.h | 1 + tests/CMakeLists.txt | 1 + tests/linalg_tests.cpp | 189 +++++++++++++++++++++++++++++++++++++++++ 6 files changed, 301 insertions(+) create mode 100644 mlx/linalg.cpp create mode 100644 mlx/linalg.h create mode 100644 tests/linalg_tests.cpp diff --git a/mlx/CMakeLists.txt b/mlx/CMakeLists.txt index bd28537f1..e004fc3d9 100644 --- a/mlx/CMakeLists.txt +++ b/mlx/CMakeLists.txt @@ -14,6 +14,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/scheduler.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ${CMAKE_CURRENT_SOURCE_DIR}/backend/metal/metal.h ) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp new file mode 100644 index 000000000..f40447954 --- /dev/null +++ b/mlx/linalg.cpp @@ -0,0 +1,64 @@ +// Copyright © 2023 Apple Inc. 
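+//
+// The vector norms below are built by composing existing element-wise and
+// reduction ops: sum(|a|^p)^(1/p) for p >= 1, the count of non-zeros for
+// p == 0, and max/min of |a| for ord "inf"/"-inf".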
+
+#include <iostream>
+#include <numeric>
+#include <ostream>
+#include <variant>
+
+#include "mlx/array.h"
+#include "mlx/linalg.h"
+#include "mlx/ops.h"
+
+namespace mlx::core::linalg {
+
+array vector_norm(
+    const array& a,
+    const std::variant<double, std::string>& ord,
+    const std::vector<int>& axes,
+    bool keepdims,
+    StreamOrDevice s) {
+  return std::visit(
+      overloaded{
+          [&](double p) {
+            if (p >= 1)
+              return power(
+                  sum(power(abs(a, s), array(p), s), axes, keepdims, s),
+                  array(1.0 / p),
+                  s);
+            else if (p == 0)
+              return sum(
+                  where(a != 0, array(1), array(0), s), axes, keepdims, s);
+            else
+              throw std::invalid_argument(
+                  "[core.linalg.norm] p norm is defined only for p >= 1.");
+          },
+          [&](const std::string& norm_type) {
+            if (norm_type == "inf")
+              return max(abs(a, s), axes, keepdims, s);
+            else if (norm_type == "-inf")
+              return min(abs(a, s), axes, keepdims, s);
+            else
+              throw std::invalid_argument(
+                  "[core.linalg.norm] Unsupported norm type for a vector.");
+          }},
+      ord);
+}
+array vector_norm(
+    const array& a,
+    const std::variant<double, std::string>& ord,
+    bool keepdims,
+    StreamOrDevice s) {
+  return vector_norm(
+      reshape(a, {static_cast<int>(a.size())}), ord, {-1}, keepdims, s);
+}
+array vector_norm(
+    const array& a,
+    const std::vector<int>& axes,
+    bool keepdims,
+    StreamOrDevice s) {
+  return vector_norm(a, 2.0, axes, keepdims, s);
+}
+array vector_norm(const array& a, bool keepdims, StreamOrDevice s) {
+  return vector_norm(a, 2.0, keepdims, s);
+}
+} // namespace mlx::core::linalg
\ No newline at end of file
diff --git a/mlx/linalg.h b/mlx/linalg.h
new file mode 100644
index 000000000..dc7d8d29d
--- /dev/null
+++ b/mlx/linalg.h
@@ -0,0 +1,45 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <variant>
+
+#include "array.h"
+#include "device.h"
+#include "ops.h"
+#include "stream.h"
+#include "string.h"
+
+namespace mlx::core::linalg {
+
+template <typename... Ts>
+struct overloaded : Ts... {
+  using Ts::operator()...;
+};
+template <typename... Ts>
+overloaded(Ts...) -> overloaded<Ts...>;
+
+/*
+Computes a vector norm.
+  If axes = {}, x will be flattened before the norm is computed.
+  Otherwise, the norm is computed over axes and the other dimensions are
+treated as batch dimensions.
+*/
+array vector_norm(
+    const array& a,
+    const std::variant<double, std::string>& ord = 2.0,
+    const std::vector<int>& axes = {},
+    bool keepdims = false,
+    StreamOrDevice s = {});
+array vector_norm(
+    const array& a,
+    const std::variant<double, std::string>& ord = 2.0,
+    bool keepdims = false,
+    StreamOrDevice s = {});
+array vector_norm(
+    const array& a,
+    const std::vector<int>& axes = {},
+    bool keepdims = false,
+    StreamOrDevice s = {});
+array vector_norm(const array& a, bool keepdims = false, StreamOrDevice s = {});
+} // namespace mlx::core::linalg
\ No newline at end of file
diff --git a/mlx/mlx.h b/mlx/mlx.h
index 102d2dde9..8d785c39f 100644
--- a/mlx/mlx.h
+++ b/mlx/mlx.h
@@ -6,6 +6,7 @@
 #include "mlx/backend/metal/metal.h"
 #include "mlx/device.h"
 #include "mlx/fft.h"
+#include "mlx/linalg.h"
 #include "mlx/ops.h"
 #include "mlx/random.h"
 #include "mlx/stream.h"
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 0879aa0f6..dbc499205 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -31,6 +31,7 @@ target_sources(tests PRIVATE
   scheduler_tests.cpp
   utils_tests.cpp
   vmap_tests.cpp
+  linalg_tests.cpp
   ${METAL_TEST_SOURCES}
 )

diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
new file mode 100644
index 000000000..6cd74357b
--- /dev/null
+++ b/tests/linalg_tests.cpp
@@ -0,0 +1,189 @@
+// Copyright © 2023 Apple Inc.
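+// Exercises vector_norm for p = 1 and p = 2 and for the "inf"/"-inf"
+// variants against hand-computed values on 1-D, 2-D, and 3-D inputs.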
+ +#include "doctest/doctest.h" + +#include +#include +#include "mlx/linalg.h" +#include "mlx/mlx.h" + +using namespace mlx::core; +using namespace mlx::core::linalg; + +TEST_CASE("vector_norm") { + // Test 1-norm on a vector + CHECK( + array_equal(vector_norm(ones({3}), 1.0, false), array(3.0)).item()); + CHECK(array_equal(vector_norm(ones({3}), 1.0, true), array({3.0})) + .item()); + // Test 1-norm on a matrix + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, false), array(36)) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, true), array({36})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, false), + array(36)) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, true), + array({36}, {1, 1})) + .item()); + // Over columns + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, false), + array({3, 12, 21})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, true), + array({3, 12, 21}, {3, 1})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, false), + array({3, 12, 21})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, true), + array({3, 12, 21}, {3, 1})) + .item()); + // Over rows + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, false), + array({9, 12, 15})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, true), + array({9, 12, 15}, {1, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, false), + array({9, 12, 15})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, true), + array({9, 12, 15}, {1, 3})) + .item()); + // Test 1-norm on a 3d tensor + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, false), array(153)) + .item()); + CHECK( + array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, true), array({153})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, false), + array(153)) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, true), + array({153}, {1, 1, 1})) + .item()); + // Over last axis + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, false), + array({3, 12, 21, 30, 39, 48}, {2, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, true), + array({3, 12, 21, 30, 39, 48}, {2, 3, 1})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, false), + array({3, 12, 21, 30, 39, 48}, {2, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, true), + array({3, 12, 21, 30, 39, 48}, {2, 3, 1})) + .item()); + // Over middle axis + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, false), + array({9, 12, 15, 36, 39, 42}, {2, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, true), + array({9, 12, 15, 36, 39, 42}, {2, 1, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, false), + array({9, 12, 15, 36, 39, 42}, {2, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, true), + array({9, 12, 15, 36, 39, 42}, {2, 1, 3})) + .item()); + // Over the first axis + CHECK(array_equal( + 
vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, false), + array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, true), + array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, false), + array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, true), + array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3})) + .item()); + // Test 2-norm on a vector + CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, false), array(5.0)) + .item()); + CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, true), array({5.0})) + .item()); + // Test that 2 is default ord + CHECK(array_equal(vector_norm({3.0, 4.0}, false), array(5.0)).item()); + CHECK(array_equal(vector_norm({3.0, 4.0}, true), array({5.0})).item()); + // Test "inf" norm on a matrix + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", false), array(8.0)) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", true), array({8.0})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, false), + array({2, 5, 8})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, true), + array({2, 5, 8}, {3, 1})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, false), + array({6.0, 7.0, 8.0})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, true), + array({6, 7, 8}, {1, 3})) + .item()); + // Test "-inf" norm on a matrix + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", false), array(0)) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", true), array({0})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, false), + array({0, 3, 6})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, true), + array({0, 3, 6}, {3, 1})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, false), + array({0, 1, 2})) + .item()); + CHECK(array_equal( + vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, true), + array({0, 1, 2}, {1, 3})) + .item()); +} \ No newline at end of file From 24da85025f1bc0d5987eb090ccfdcb21c3ba7330 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Sun, 17 Dec 2023 07:06:04 +0100 Subject: [PATCH 19/37] implemented vector_norm python binding --- python/src/CMakeLists.txt | 1 + python/src/linalg.cpp | 88 +++++++++++++++++++++++++++++++++++++++ python/src/mlx.cpp | 2 + 3 files changed, 91 insertions(+) create mode 100644 python/src/linalg.cpp diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt index 5ab8a50bf..1ad9d207d 100644 --- a/python/src/CMakeLists.txt +++ b/python/src/CMakeLists.txt @@ -11,6 +11,7 @@ pybind11_add_module( ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/linalg.cpp ) if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp new file mode 100644 index 000000000..b389ac15f --- /dev/null +++ b/python/src/linalg.cpp @@ -0,0 +1,88 @@ + +// Copyright © 2023 Apple Inc. 
+
+#include <numeric>
+#include <ostream>
+#include <variant>
+
+#include <pybind11/numpy.h>
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "mlx/linalg.h"
+#include "mlx/ops.h"
+#include "mlx/utils.h"
+
+#include "python/src/load.h"
+#include "python/src/utils.h"
+
+namespace py = pybind11;
+using namespace py::literals;
+
+using namespace mlx::core;
+using namespace mlx::core::linalg;
+
+void init_linalg(py::module_& parent_module) {
+  auto m =
+      parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
+
+  m.def(
+      "vector_norm",
+      [](const array& a,
+         const std::variant<double, std::string>& ord,
+         const std::variant<std::monostate, int, std::vector<int>>& axis,
+         bool keepdims,
+         StreamOrDevice s) {
+        std::vector<int> axes = std::visit(
+            overloaded{
+                [](std::monostate s) { return std::vector<int>(); },
+                [](int axis) { return std::vector<int>({axis}); },
+                [](const std::vector<int> axes) { return axes; }},
+            axis);
+
+        if (axes.empty())
+          return vector_norm(a, ord, keepdims, s);
+        else
+          return vector_norm(a, ord, axes, keepdims, s);
+      },
+      "a"_a,
+      "ord"_a = 2.0,
+      "axis"_a = none,
+      "keepdims"_a = false,
+      "stream"_a = none,
+      R"pbdoc(
+        Computes a vector norm.
+
+        - If :attr:`axis`\ `= None`, :attr:`a` will be flattened before the norm is computed.
+        - If :attr:`axis` is an `int` or a `tuple`, the norm will be computed over these dimensions
+          and the other dimensions will be treated as batch dimensions.
+
+
+        :attr:`ord` defines the vector norm that is computed. The following norms are supported:
+
+        ======================   ===============================
+        :attr:`ord`              vector norm
+        ======================   ===============================
+        `2` (default)            `2`-norm (see below)
+        `inf`                    `max(abs(x))`
+        `-inf`                   `min(abs(x))`
+        `0`                      `sum(x != 0)`
+        other `int` or `float`   `sum(abs(x)^{ord})^{(1 / ord)}`
+        ======================   ===============================
+
+        where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
+
+        Args:
+            a (array): input array, flattened by default, but this behavior can be
+                controlled using :attr:`axis`.
+            ord (int, float, 'inf', '-inf', optional): order of norm. Default: `2`
+            axis (int, Tuple[int], optional): dimensions over which to compute
+                the norm. See above for the behavior when :attr:`axis`\ `= None`.
+                Default: `None`
+            keepdims (bool, optional): If set to `True`, the reduced dimensions are retained
+                in the result as dimensions with size one. Default: `False`
+
+        Returns:
+            A real-valued array, even when :attr:`a` is complex.
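+
+        Example:
+            A quick sketch of the intended behavior (the output formatting
+            shown is illustrative):
+
+            >>> import mlx.core as mx
+            >>> v = mx.array([3.0, 4.0])
+            >>> mx.linalg.vector_norm(v)  # default 2-norm
+            array(5, dtype=float32)
+            >>> mx.linalg.vector_norm(v, ord=1.0)
+            array(7, dtype=float32)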
+ )pbdoc"); +} diff --git a/python/src/mlx.cpp b/python/src/mlx.cpp index ebadf767d..d7cf15751 100644 --- a/python/src/mlx.cpp +++ b/python/src/mlx.cpp @@ -15,6 +15,7 @@ void init_ops(py::module_&); void init_transforms(py::module_&); void init_random(py::module_&); void init_fft(py::module_&); +void init_linalg(py::module_&); PYBIND11_MODULE(core, m) { m.doc() = "mlx: A framework for machine learning on Apple silicon."; @@ -29,5 +30,6 @@ PYBIND11_MODULE(core, m) { init_transforms(m); init_random(m); init_fft(m); + init_linalg(m); m.attr("__version__") = TOSTRING(_VERSION_); } From 05203ecd78004d064ce32f06b8c1855cef4b2a34 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Wed, 20 Dec 2023 03:13:18 +0100 Subject: [PATCH 20/37] renamed vector_norm to norm, implemented norm without provided ord --- mlx/linalg.cpp | 69 +++++-------- mlx/linalg.h | 30 +----- tests/linalg_tests.cpp | 228 ++++++++++++----------------------------- 3 files changed, 93 insertions(+), 234 deletions(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index f40447954..b49713afa 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -3,62 +3,39 @@ #include #include #include +#include #include +#include #include "mlx/array.h" #include "mlx/linalg.h" #include "mlx/ops.h" +#include "utils.h" namespace mlx::core::linalg { -array vector_norm( +inline std::vector get_shape_reducing_over_all_dims(int num_axes) { + std::vector shape(num_axes); + std::iota(shape.begin(), shape.end(), 0); + return shape; +} + +array norm( const array& a, - const std::variant& ord, - const std::vector& axes, + const std::vector& axis, bool keepdims, StreamOrDevice s) { - return std::visit( - overloaded{ - [&](double p) { - if (p >= 1) - return power( - sum(power(abs(a, s), array(p), s), axes, keepdims, s), - array(1.0 / p), - s); - else if (p == 0) - return sum( - where(a != 0, array(1), array(0), s), axes, keepdims, s); - else - throw std::invalid_argument( - "[core.linalg.norm] p norm is defined only for p >= 1."); - }, - [&](const std::string& norm_type) { - if (norm_type == "inf") - return max(abs(a, s), axes, keepdims, s); - else if (norm_type == "-inf") - return min(abs(a, s), axes, keepdims, s); - else - throw std::invalid_argument( - "[core.linalg.norm] Unsupported norm type for a vector."); - }}, - ord); -} -array vector_norm( - const array& a, - const std::variant& ord, - bool keepdims, - StreamOrDevice s) { - return vector_norm( - reshape(a, {static_cast(a.size())}), ord, {-1}, keepdims, s); -} -array vector_norm( - const array& a, - const std::vector& axes, - bool keepdims, - StreamOrDevice s) { - return vector_norm(a, 2.0, axes, keepdims, s); -} -array vector_norm(const array& a, bool keepdims, StreamOrDevice s) { - return vector_norm(a, 2.0, keepdims, s); + auto num_axes = axis.size(); + + if (num_axes == 0 || num_axes == 1 || num_axes == 2) + return sqrt(sum( + abs(a, s) * abs(a, s), + num_axes ? axis : get_shape_reducing_over_all_dims(a.shape().size()), + keepdims, + s)); + + std::stringstream error_stream; + error_stream << "Invalid axis values" << axis; + throw std::invalid_argument(error_stream.str()); } } // namespace mlx::core::linalg \ No newline at end of file diff --git a/mlx/linalg.h b/mlx/linalg.h index dc7d8d29d..fa9658bbb 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -11,35 +11,9 @@ #include "string.h" namespace mlx::core::linalg { - -template -struct overloaded : Ts... { - using Ts::operator()...; -}; -template -overloaded(Ts...) -> overloaded; - -/* -Computes a vector norm. 
- If axes = {}, x will be flattened before the norm is computed. - Otherwise, the norm is computed over axes and the other dimensions are -treated as batch dimensions. -*/ -array vector_norm( +array norm( const array& a, - const std::variant& ord = 2.0, - const std::vector& axes = {}, + const std::vector& axis = {}, bool keepdims = false, StreamOrDevice s = {}); -array vector_norm( - const array& a, - const std::variant& ord = 2.0, - bool keepdims = false, - StreamOrDevice s = {}); -array vector_norm( - const array& a, - const std::vector& axes = {}, - bool keepdims = false, - StreamOrDevice s = {}); -array vector_norm(const array& a, bool keepdims = false, StreamOrDevice s = {}); } // namespace mlx::core::linalg \ No newline at end of file diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 6cd74357b..0b6b31801 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -10,180 +10,88 @@ using namespace mlx::core; using namespace mlx::core::linalg; -TEST_CASE("vector_norm") { - // Test 1-norm on a vector - CHECK( - array_equal(vector_norm(ones({3}), 1.0, false), array(3.0)).item()); - CHECK(array_equal(vector_norm(ones({3}), 1.0, true), array({3.0})) - .item()); - // Test 1-norm on a matrix +TEST_CASE("[mlx.core.linalg.norm] no ord") { + array arr_one_d({1, 2, 3}); + array arr_two_d = reshape(arange(9), {3, 3}); + array arr_three_d = reshape(arange(18), {2, 3, 3}); + + CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item()); + CHECK(array_equal(norm(arr_one_d, {0}), array(sqrt(1 + 4 + 9))).item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, false), array(36)) + norm(arr_two_d), + array(sqrt( + 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8))) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, true), array({36})) + norm(arr_two_d, {0}), + array( + {sqrt(0 + 3 * 3 + 6 * 6), + sqrt(1 + 4 * 4 + 7 * 7), + sqrt(2 * 2 + 5 * 5 + 8 * 8)})) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, false), - array(36)) + norm(arr_two_d, {1}), + array( + {sqrt(0 + 1 + 2 * 2), + sqrt(3 * 3 + 4 * 4 + 5 * 5), + sqrt(6 * 6 + 7 * 7 + 8 * 8)})) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {0, 1}, true), - array({36}, {1, 1})) - .item()); - // Over columns - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, false), - array({3, 12, 21})) + norm(arr_two_d, {0, 1}), + array(sqrt( + 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8))) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {1}, true), - array({3, 12, 21}, {3, 1})) + norm(arr_three_d, {2}), + array( + { + sqrt(0 + 1 + 2 * 2), + sqrt(3 * 3 + 4 * 4 + 5 * 5), + sqrt(6 * 6 + 7 * 7 + 8 * 8), + sqrt(9 * 9 + 10 * 10 + 11 * 11), + sqrt(12 * 12 + 13 * 13 + 14 * 14), + sqrt(15 * 15 + 16 * 16 + 17 * 17), + }, + {2, 3})) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, false), - array({3, 12, 21})) + norm(arr_three_d, {1}), + array( + { + sqrt(0 + 3 * 3 + 6 * 6), + sqrt(1 + 4 * 4 + 7 * 7), + sqrt(2 * 2 + 5 * 5 + 8 * 8), + sqrt(9 * 9 + 12 * 12 + 15 * 15), + sqrt(10 * 10 + 13 * 13 + 16 * 16), + sqrt(11 * 11 + 14 * 14 + 17 * 17), + }, + {2, 3})) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {-1}, true), - array({3, 12, 21}, {3, 1})) - .item()); - // Over rows - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, false), - array({9, 12, 15})) + norm(arr_three_d, {0}), + array( 
+ { + sqrt(0 + 9 * 9), + sqrt(1 + 10 * 10), + sqrt(2 * 2 + 11 * 11), + sqrt(3 * 3 + 12 * 12), + sqrt(4 * 4 + 13 * 13), + sqrt(5 * 5 + 14 * 14), + sqrt(6 * 6 + 15 * 15), + sqrt(7 * 7 + 16 * 16), + sqrt(8 * 8 + 17 * 17), + }, + {3, 3})) .item()); CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {0}, true), - array({9, 12, 15}, {1, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, false), - array({9, 12, 15})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), 1.0, {-2}, true), - array({9, 12, 15}, {1, 3})) - .item()); - // Test 1-norm on a 3d tensor - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, false), array(153)) - .item()); - CHECK( - array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, true), array({153})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, false), - array(153)) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0, 1, 2}, true), - array({153}, {1, 1, 1})) - .item()); - // Over last axis - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, false), - array({3, 12, 21, 30, 39, 48}, {2, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {2}, true), - array({3, 12, 21, 30, 39, 48}, {2, 3, 1})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, false), - array({3, 12, 21, 30, 39, 48}, {2, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-1}, true), - array({3, 12, 21, 30, 39, 48}, {2, 3, 1})) - .item()); - // Over middle axis - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, false), - array({9, 12, 15, 36, 39, 42}, {2, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {1}, true), - array({9, 12, 15, 36, 39, 42}, {2, 1, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, false), - array({9, 12, 15, 36, 39, 42}, {2, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-2}, true), - array({9, 12, 15, 36, 39, 42}, {2, 1, 3})) - .item()); - // Over the first axis - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, false), - array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {0}, true), - array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, false), - array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {3, 3})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(18), {2, 3, 3}), 1.0, {-3}, true), - array({9, 11, 13, 15, 17, 19, 21, 23, 25}, {1, 3, 3})) - .item()); - // Test 2-norm on a vector - CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, false), array(5.0)) - .item()); - CHECK(array_equal(vector_norm({3.0, 4.0}, 2.0, true), array({5.0})) - .item()); - // Test that 2 is default ord - CHECK(array_equal(vector_norm({3.0, 4.0}, false), array(5.0)).item()); - CHECK(array_equal(vector_norm({3.0, 4.0}, true), array({5.0})).item()); - // Test "inf" norm on a matrix - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "inf", false), array(8.0)) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "inf", true), array({8.0})) - .item()); - CHECK(array_equal( - 
vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, false), - array({2, 5, 8})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "inf", {1}, true), - array({2, 5, 8}, {3, 1})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, false), - array({6.0, 7.0, 8.0})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "inf", {0}, true), - array({6, 7, 8}, {1, 3})) - .item()); - // Test "-inf" norm on a matrix - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", false), array(0)) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", true), array({0})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, false), - array({0, 3, 6})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", {1}, true), - array({0, 3, 6}, {3, 1})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, false), - array({0, 1, 2})) - .item()); - CHECK(array_equal( - vector_norm(reshape(arange(9), {3, 3}), "-inf", {0}, true), - array({0, 1, 2}, {1, 3})) + norm(arr_three_d, {1, 2}), + array( + {sqrt( + 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + + 8 * 8), + sqrt( + 9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 + + 15 * 15 + 16 * 16 + 17 * 17)}, + {2})) .item()); } \ No newline at end of file From 8c43d820d99af8cc7b273a08fa801970b3494e42 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Thu, 21 Dec 2023 18:33:23 +0100 Subject: [PATCH 21/37] completed the implementation of the norm --- mlx/linalg.cpp | 140 +++++++++++++++++++++++++++++++++++++++++++++---- mlx/linalg.h | 12 +++++ mlx/utils.cpp | 13 +++++ mlx/utils.h | 5 ++ 4 files changed, 159 insertions(+), 11 deletions(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index b49713afa..1847896d2 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include "mlx/array.h" @@ -14,10 +14,77 @@ namespace mlx::core::linalg { -inline std::vector get_shape_reducing_over_all_dims(int num_axes) { - std::vector shape(num_axes); - std::iota(shape.begin(), shape.end(), 0); - return shape; +inline array vector_norm( + const array& a, + const double ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + if (ord == 0.0) + return sum(a != 0, axis, keepdims, s); + else if (ord == 1.0) + return sum(abs(a, s), axis, keepdims, s); + else if (ord == 2.0) + return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s)); + else + return power( + sum(power(abs(a, s), array(ord), s), axis, keepdims, s), + array(1.0 / ord)); +} + +inline array vector_norm( + const array& a, + const std::string& ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + if (ord == "inf") + return max(abs(a, s), axis, keepdims, s); + else if (ord == "-inf") + return min(abs(a, s), axis, keepdims, s); + std::stringstream error_stream; + error_stream << "Invalid ord value " << ord; + throw std::invalid_argument(error_stream.str()); +} + +inline array matrix_norm( + const array& a, + const double ord, + const std::vector& axis, + bool keepdims, + StreamOrDevice s) { + auto row_axis = axis[0]; + auto col_axis = axis[1]; + if (!keepdims && col_axis > row_axis) + col_axis -= 1; + if (ord == -1.0) + return min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); + if (ord == 1.0) + return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); + if (ord == 2.0 || ord 
== -2.0)
+    throw std::logic_error("Singular value norms are not implemented.");
+  std::stringstream error_stream;
+  error_stream << "Invalid ord value " << ord << " for matrix norm";
+  throw std::invalid_argument(error_stream.str());
+}
+
+inline array matrix_norm(
+    const array& a,
+    const std::string& ord,
+    const std::vector<int>& axis,
+    bool keepdims,
+    StreamOrDevice s) {
+  if (ord == "f" || ord == "fro")
+    return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
+  else if (ord == "inf")
+    return matrix_norm(a, 1.0, {axis[1], axis[0]}, keepdims, s);
+  else if (ord == "-inf")
+    return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s);
+  if (ord == "nuc")
+    throw std::logic_error("Nuclear norm is not implemented.");
+  std::stringstream error_stream;
+  error_stream << "Invalid ord value " << ord << " for matrix norm";
+  throw std::invalid_argument(error_stream.str());
+}

 array norm(
     const array& a,
     const std::vector<int>& axis,
     bool keepdims,
     StreamOrDevice s) {
   auto num_axes = axis.size();

   if (num_axes == 0 || num_axes == 1 || num_axes == 2)
-    return sqrt(sum(
-        abs(a, s) * abs(a, s),
-        num_axes ? axis : get_shape_reducing_over_all_dims(a.shape().size()),
-        keepdims,
-        s));
+    return sqrt(
+        sum(abs(a, s) * abs(a, s),
+            num_axes ? axis
+                     : get_shape_reducing_over_all_axes(a.shape().size()),
+            keepdims,
+            s),
+        s);

   std::stringstream error_stream;
-  error_stream << "Invalid axis values" << axis;
+  error_stream << "Invalid axis values " << axis;
   throw std::invalid_argument(error_stream.str());
 }
+
+array norm(
+    const array& a,
+    const double ord,
+    const std::vector<int>& axis,
+    bool keepdims,
+    StreamOrDevice s) {
+  std::vector<int> ax = axis;
+
+  if (axis.empty())
+    ax = get_shape_reducing_over_all_axes(a.ndim());
+  else
+    ax = normalize_axes(ax, a.ndim());
+
+  auto num_axes = ax.size();
+  if (num_axes == 1)
+    return vector_norm(a, ord, ax, keepdims, s);
+  else if (num_axes == 2)
+    return matrix_norm(a, ord, ax, keepdims, s);
+
+  std::stringstream error_stream;
+  error_stream << "Invalid axis values " << ax;
+  throw std::invalid_argument(error_stream.str());
+}
+
+array norm(
+    const array& a,
+    const std::string& ord,
+    const std::vector<int>& axis,
+    bool keepdims,
+    StreamOrDevice s) {
+  std::vector<int> ax = axis;
+
+  if (axis.empty())
+    ax = get_shape_reducing_over_all_axes(a.ndim());
+  else
+    ax = normalize_axes(ax, a.ndim());
+
+  auto num_axes = ax.size();
+  if (num_axes == 1)
+    return vector_norm(a, ord, ax, keepdims, s);
+  else if (num_axes == 2)
+    return matrix_norm(a, ord, ax, keepdims, s);
+
+  std::stringstream error_stream;
+  error_stream << "Invalid axis values " << ax;
+  throw std::invalid_argument(error_stream.str());
+}
+
 } // namespace mlx::core::linalg
\ No newline at end of file
diff --git a/mlx/linalg.h b/mlx/linalg.h
index fa9658bbb..690df343c 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -11,6 +11,18 @@
 #include "string.h"

 namespace mlx::core::linalg {
+array norm(
+    const array& a,
+    const double ord,
+    const std::vector<int>& axis = {},
+    bool keepdims = false,
+    StreamOrDevice s = {});
+array norm(
+    const array& a,
+    const std::string& ord,
+    const std::vector<int>& axis = {},
+    bool keepdims = false,
+    StreamOrDevice s = {});
 array norm(
     const array& a,
     const std::vector<int>& axis = {},
     bool keepdims = false,
     StreamOrDevice s = {});
 } // namespace mlx::core::linalg
\ No newline at end of file
diff --git a/mlx/utils.cpp b/mlx/utils.cpp
index 1fbc67c8e..ddcb41ba8 100644
--- a/mlx/utils.cpp
+++ b/mlx/utils.cpp
@@ -1,5 +1,6 @@
 // Copyright © 2023 Apple Inc.
+#include <numeric>
 #include <sstream>
 #include <vector>

@@ -73,6 +74,12 @@ int normalize_axis(int axis, int ndim) {
   }
   return axis;
 }
+std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim) {
+  std::vector<int> canonical;
+  for (int ax : axes)
+    canonical.push_back(normalize_axis(ax, ndim));
+  return canonical;
+}

 std::ostream& operator<<(std::ostream& os, const Device& d) {
   os << "Device(";
@@ -279,4 +286,10 @@ std::ostream& operator<<(std::ostream& os, const std::vector<int>& v) {
   return os;
 }

+std::vector<int> get_shape_reducing_over_all_axes(int ndim) {
+  std::vector<int> shape(ndim);
+  std::iota(shape.begin(), shape.end(), 0);
+  return shape;
+}
+
 } // namespace mlx::core
diff --git a/mlx/utils.h b/mlx/utils.h
index 823b4c872..1158b7c42 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -24,6 +24,7 @@ bool is_same_shape(const std::vector<array>& arrays);
 * https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html
 */
 int normalize_axis(int axis, int ndim);
+std::vector<int> normalize_axes(const std::vector<int>& axes, int ndim);

 std::ostream& operator<<(std::ostream& os, const Device& d);
 std::ostream& operator<<(std::ostream& os, const Stream& s);
@@ -41,4 +42,8 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
 inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
   return os << static_cast<float>(v);
 }
+/**
+ * Returns the axes vector [0, 1, ... ndim).
+ */
+std::vector<int> get_shape_reducing_over_all_axes(int ndim);
 } // namespace mlx::core

From 5d7a06717c0d10f9b409faaf572de97dc53756b9 Mon Sep 17 00:00:00 2001
From: Gabrijel Boduljak
Date: Thu, 21 Dec 2023 18:34:02 +0100
Subject: [PATCH 22/37] added tests

---
 tests/linalg_tests.cpp | 227 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 218 insertions(+), 9 deletions(-)

diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index 0b6b31801..9841f03bf 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -16,33 +16,34 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") {
  array arr_one_d({1, 2, 3});
  array arr_two_d = reshape(arange(9), {3, 3});
  array arr_three_d = reshape(arange(18), {2, 3, 3});

  CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item<bool>());
  CHECK(array_equal(norm(arr_one_d, {0}, false), array(sqrt(1 + 4 + 9)))
            .item<bool>());
  CHECK(array_equal(
            norm(arr_two_d, {}, false),
            array(sqrt(
                0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
            .item<bool>());
  CHECK(array_equal(
            norm(arr_two_d, {0}, false),
            array(
                {sqrt(0 + 3 * 3 + 6 * 6),
                 sqrt(1 + 4 * 4 + 7 * 7),
                 sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
            .item<bool>());
  CHECK(array_equal(
            norm(arr_two_d, {1}, false),
            array(
                {sqrt(0 + 1 + 2 * 2),
                 sqrt(3 * 3 + 4 * 4 + 5 * 5),
                 sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
            .item<bool>());
  CHECK(array_equal(
            norm(arr_two_d, {0, 1}, false),
            array(sqrt(
                0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
            .item<bool>());
  CHECK(array_equal(
            norm(arr_three_d, {2}, false),
            array(
                {
                    sqrt(0 + 1 + 2 * 2),
                    sqrt(3 * 3 + 4 * 4 + 5 * 5),
                    sqrt(6 * 6 + 7 * 7 + 8 * 8),
                    sqrt(9 * 9 + 10 * 10 + 11 * 11),
                    sqrt(12 * 12 + 13 * 13 + 14 * 14),
                    sqrt(15 * 15 + 16 * 16 + 17 * 17),
                },
                {2, 3}))
            .item<bool>());
  CHECK(array_equal(
            norm(arr_three_d, {1}, false),
            array(
                {
                    sqrt(0 + 3 * 3 + 6 * 6),
                    sqrt(1 + 4 * 4 + 7 * 7),
                    sqrt(2 * 2 + 5 * 5 + 8 * 8),
                    sqrt(9 * 9 + 12 * 12 + 15 * 15),
                    sqrt(10 * 10 + 13 * 13 + 16 * 16),
                    sqrt(11 * 11 + 14 * 14 + 17 * 17),
                },
                {2, 3}))
            .item<bool>());
  CHECK(array_equal(
            norm(arr_three_d, {0}, false),
            array(
                {
                    sqrt(0 + 9 * 9),
                    sqrt(1 + 10 * 10),
                    sqrt(2 * 2 + 11 * 11),
                    sqrt(3 * 3 + 12 * 12),
                    sqrt(4 * 4 + 13 * 13),
                    sqrt(5 * 5 + 14 * 14),
                    sqrt(6 * 6 + 15 * 15),
                    sqrt(7 * 7 + 16 * 16),
                    sqrt(8 * 8 + 17 * 17),
                },
                {3, 3}))
            .item<bool>());
  CHECK(array_equal(
norm(arr_three_d, {1, 2}, false), array( {sqrt( 0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + @@ -94,4 +95,212 @@ TEST_CASE("[mlx.core.linalg.norm] no ord") { 15 * 15 + 16 * 16 + 17 * 17)}, {2})) .item()); +} + +TEST_CASE("[mlx.core.linalg.norm] double ord") { + array arr_one_d({1, 2, 3}); + array arr_two_d = reshape(arange(9), {3, 3}); + array arr_three_d = reshape(arange(18), {2, 3, 3}); + + CHECK(array_equal(norm(arr_one_d, 2.0), array(sqrt(1 + 4 + 9))).item()); + CHECK(array_equal(norm(arr_one_d, 1.0), array(1 + 2 + 3)).item()); + CHECK(array_equal(norm(arr_one_d, 0.0), array(3)).item()); + + CHECK(array_equal(norm(arr_one_d, 2.0, {0}, false), array(sqrt(1 + 4 + 9))) + .item()); + CHECK(array_equal( + norm(arr_two_d, 2.0, {0}, false), + array( + {sqrt(0 + 3 * 3 + 6 * 6), + sqrt(1 + 4 * 4 + 7 * 7), + sqrt(2 * 2 + 5 * 5 + 8 * 8)})) + .item()); + CHECK(array_equal( + norm(arr_two_d, 2.0, {1}, false), + array( + {sqrt(0 + 1 + 2 * 2), + sqrt(3 * 3 + 4 * 4 + 5 * 5), + sqrt(6 * 6 + 7 * 7 + 8 * 8)})) + .item()); + CHECK(array_equal( + norm(arr_three_d, 2.0, {2}, false), + array( + { + sqrt(0 + 1 + 2 * 2), + sqrt(3 * 3 + 4 * 4 + 5 * 5), + sqrt(6 * 6 + 7 * 7 + 8 * 8), + sqrt(9 * 9 + 10 * 10 + 11 * 11), + sqrt(12 * 12 + 13 * 13 + 14 * 14), + sqrt(15 * 15 + 16 * 16 + 17 * 17), + }, + {2, 3})) + .item()); + CHECK(array_equal( + norm(arr_three_d, 2.0, {1}, false), + array( + { + sqrt(0 + 3 * 3 + 6 * 6), + sqrt(1 + 4 * 4 + 7 * 7), + sqrt(2 * 2 + 5 * 5 + 8 * 8), + sqrt(9 * 9 + 12 * 12 + 15 * 15), + sqrt(10 * 10 + 13 * 13 + 16 * 16), + sqrt(11 * 11 + 14 * 14 + 17 * 17), + }, + {2, 3})) + .item()); + CHECK(array_equal( + norm(arr_three_d, 2.0, {0}, false), + array( + { + sqrt(0 + 9 * 9), + sqrt(1 + 10 * 10), + sqrt(2 * 2 + 11 * 11), + sqrt(3 * 3 + 12 * 12), + sqrt(4 * 4 + 13 * 13), + sqrt(5 * 5 + 14 * 14), + sqrt(6 * 6 + 15 * 15), + sqrt(7 * 7 + 16 * 16), + sqrt(8 * 8 + 17 * 17), + }, + {3, 3})) + .item()); + + CHECK(allclose( + norm(arr_three_d, 3.0, {0}), + array( + {9., + 10.00333222, + 11.02199456, + 12.06217728, + 13.12502645, + 14.2094363, + 15.31340617, + 16.43469751, + 17.57113899}, + {3, 3})) + .item()); + CHECK( + allclose( + norm(arr_three_d, 3.0, {1}), + array( + {6.24025147, 7.41685954, 8.6401226, 18., 19.39257164, 20.7915893}, + {2, 3})) + .item()); + CHECK(allclose( + norm(arr_three_d, 3.0, {2}), + array( + {2.08008382, + 6., + 10.23127655, + 14.5180117, + 18.82291607, + 23.13593104}, + {2, 3})) + .item()); + CHECK(allclose( + norm(arr_three_d, 0.0, {0}), + array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3})) + .item()); + CHECK( + allclose( + norm(arr_three_d, 0.0, {1}), array({2., 3., 3., 3., 3., 3.}, {2, 3})) + .item()); + CHECK( + allclose( + norm(arr_three_d, 0.0, {2}), array({2., 3., 3., 3., 3., 3.}, {2, 3})) + .item()); + CHECK(allclose( + norm(arr_three_d, 1.0, {0}), + array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3})) + .item()); + CHECK(allclose( + norm(arr_three_d, 1.0, {1}), + array({9., 12., 15., 36., 39., 42.}, {2, 3})) + .item()); + CHECK(allclose( + norm(arr_three_d, 1.0, {2}), + array({3., 12., 21., 30., 39., 48.}, {2, 3})) + .item()); + + CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}), array({15.0})).item()); + CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}), array({21.0})).item()); + CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}), array({9.0})).item()); + CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}), array({3.0})).item()); + + CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}, true), array({15.0}, {1, 1})) + .item()); + CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}, 
true), array({21.0}, {1, 1})) + .item()); + CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}, true), array({9.0}, {1, 1})) + .item()); + CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}, true), array({3.0}, {1, 1})) + .item()); + + CHECK(array_equal(norm(arr_two_d, -1.0, {-2, -1}, false), array(9.0)) + .item()); + CHECK(array_equal(norm(arr_two_d, 1.0, {-2, -1}, false), array(15.0)) + .item()); + // + CHECK(allclose(norm(arr_three_d, 1.0, {0, 1}), array({21., 23., 25.})) + .item()); + CHECK( + allclose(norm(arr_three_d, 1.0, {1, 2}), array({15., 42.})).item()); + CHECK(allclose(norm(arr_three_d, -1.0, {0, 1}), array({9., 11., 13.})) + .item()); + CHECK( + allclose(norm(arr_three_d, -1.0, {1, 2}), array({9., 36.})).item()); + CHECK(allclose(norm(arr_three_d, -1.0, {1, 0}), array({9., 12., 15.})) + .item()); + CHECK(allclose(norm(arr_three_d, -1.0, {2, 1}), array({3, 30})).item()); + CHECK(allclose(norm(arr_three_d, -1.0, {1, 2}), array({9, 36})).item()); +} + +TEST_CASE("[mlx.core.linalg.norm] string ord") { + array arr_one_d({1, 2, 3}); + array arr_two_d = reshape(arange(9), {3, 3}); + array arr_three_d = reshape(arange(18), {2, 3, 3}); + + CHECK(allclose(norm(arr_one_d, "inf", {}), array({3.0})).item()); + CHECK(allclose(norm(arr_one_d, "-inf", {}), array({1.0})).item()); + + CHECK(allclose(norm(arr_two_d, "f", {0, 1}), array({14.2828568570857})) + .item()); + CHECK(allclose(norm(arr_two_d, "fro", {0, 1}), array({14.2828568570857})) + .item()); + CHECK(allclose(norm(arr_two_d, "inf", {0, 1}), array({21.0})).item()); + CHECK(allclose(norm(arr_two_d, "-inf", {0, 1}), array({3.0})).item()); + + CHECK(allclose( + norm(arr_three_d, "fro", {0, 1}), + array({22.24859546, 24.31049156, 26.43860813})) + .item()); + CHECK(allclose( + norm(arr_three_d, "fro", {1, 2}), array({14.28285686, 39.7617907})) + .item()); + CHECK(allclose( + norm(arr_three_d, "f", {0, 1}), + array({22.24859546, 24.31049156, 26.43860813})) + .item()); + CHECK(allclose( + norm(arr_three_d, "f", {1, 0}), + array({22.24859546, 24.31049156, 26.43860813})) + .item()); + CHECK( + allclose(norm(arr_three_d, "f", {1, 2}), array({14.28285686, 39.7617907})) + .item()); + CHECK( + allclose(norm(arr_three_d, "f", {2, 1}), array({14.28285686, 39.7617907})) + .item()); + CHECK(allclose(norm(arr_three_d, "inf", {0, 1}), array({36., 39., 42.})) + .item()); + CHECK(allclose(norm(arr_three_d, "inf", {1, 2}), array({21., 48.})) + .item()); + CHECK(allclose(norm(arr_three_d, "-inf", {0, 1}), array({9., 12., 15.})) + .item()); + CHECK(allclose(norm(arr_three_d, "-inf", {1, 2}), array({3., 30.})) + .item()); + CHECK(allclose(norm(arr_three_d, "-inf", {1, 0}), array({9., 11., 13.})) + .item()); + CHECK(allclose(norm(arr_three_d, "-inf", {2, 1}), array({9., 36.})) + .item()); } \ No newline at end of file From b996d682d9ae50193a30f6d06d3a14fd9b6d4f54 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Thu, 21 Dec 2023 18:36:50 +0100 Subject: [PATCH 23/37] removed unused import in linalg.cpp --- mlx/linalg.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 1847896d2..614e6f79c 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -1,7 +1,6 @@ // Copyright © 2023 Apple Inc. 
 #include <iostream>
-#include <numeric>
 #include <ostream>
 #include <sstream>
 #include <variant>

From fa096d64a241f5284e0a96f04af3b388269d1123 Mon Sep 17 00:00:00 2001
From: Gabrijel Boduljak
Date: Thu, 21 Dec 2023 19:09:36 +0100
Subject: [PATCH 24/37] updated python bindings

---
 python/src/linalg.cpp | 169 +++++++++++++++++++++++++++++------------
 1 file changed, 116 insertions(+), 53 deletions(-)

diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index b389ac15f..c2728c738 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -27,62 +27,125 @@ void init_linalg(py::module_& parent_module) {
       parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");

   m.def(
-      "vector_norm",
-      [](const array& a,
-         const std::variant<double, std::string>& ord,
-         const std::variant<std::monostate, int, std::vector<int>>& axis,
-         bool keepdims,
-         StreamOrDevice s) {
-        std::vector<int> axes = std::visit(
-            overloaded{
-                [](std::monostate s) { return std::vector<int>(); },
-                [](int axis) { return std::vector<int>({axis}); },
-                [](const std::vector<int> axes) { return axes; }},
-            axis);
-
-        if (axes.empty())
-          return vector_norm(a, ord, keepdims, s);
-        else
-          return vector_norm(a, ord, axes, keepdims, s);
+      "norm",
+      [](const array& a, const bool keepdims, const StreamOrDevice stream) {
+        return norm(a, {}, keepdims, stream);
       },
       "a"_a,
-      "ord"_a = 2.0,
-      "axis"_a = none,
       "keepdims"_a = false,
       "stream"_a = none,
-      R"pbdoc(
-        Computes a vector norm.
-
-        - If :attr:`axis`\ `= None`, :attr:`a` will be flattened before the norm is computed.
-        - If :attr:`axis` is an `int` or a `tuple`, the norm will be computed over these dimensions
-          and the other dimensions will be treated as batch dimensions.
-
-
-        :attr:`ord` defines the vector norm that is computed. The following norms are supported:
-
-        ======================   ===============================
-        :attr:`ord`              vector norm
-        ======================   ===============================
-        `2` (default)            `2`-norm (see below)
-        `inf`                    `max(abs(x))`
-        `-inf`                   `min(abs(x))`
-        `0`                      `sum(x != 0)`
-        other `int` or `float`   `sum(abs(x)^{ord})^{(1 / ord)}`
-        ======================   ===============================
-
-        where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object.
-
-        Args:
-            a (array): input array, flattened by default, but this behavior can be
-                controlled using :attr:`axis`.
-            ord (int, float, 'inf', '-inf', optional): order of norm. Default: `2`
-            axis (int, Tuple[int], optional): dimensions over which to compute
-                the norm. See above for the behavior when :attr:`axis`\ `= None`.
-                Default: `None`
-            keepdims (bool, optional): If set to `True`, the reduced dimensions are retained
-                in the result as dimensions with size one. Default: `False`
-
-        Returns:
-            A real-valued array, even when :attr:`a` is complex.
-
-        Example:
-            A quick sketch of the intended behavior (the output formatting
-            shown is illustrative):
-
-            >>> import mlx.core as mx
-            >>> v = mx.array([3.0, 4.0])
-            >>> mx.linalg.vector_norm(v)  # default 2-norm
-            array(5, dtype=float32)
-            >>> mx.linalg.vector_norm(v, ord=1.0)
-            array(7, dtype=float32)
-      )pbdoc");
+      R"pbdoc()pbdoc");
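+
+  // Each m.def below registers one combination of (ord, axis) argument
+  // types for norm, mirroring numpy.linalg.norm's call patterns; the empty
+  // docstrings are placeholders that are filled in by a later patch.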
- )pbdoc"); + m.def( + "norm", + [](const array& a, + const int axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, {axis}, keepdims, stream); + }, + "a"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const std::vector& axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, axis, keepdims, stream); + }, + "a"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const double ord, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, {}, keepdims, stream); + }, + "a"_a, + "ord"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const double ord, + const int axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, {axis}, keepdims, stream); + }, + "a"_a, + "ord"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const double ord, + const std::vector& axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, axis, keepdims, stream); + }, + "a"_a, + "ord"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const std::string& ord, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, {}, keepdims, stream); + }, + "a"_a, + "ord"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const std::string& ord, + const int axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, {axis}, keepdims, stream); + }, + "a"_a, + "ord"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); + m.def( + "norm", + [](const array& a, + const std::string& ord, + const std::vector& axis, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, axis, keepdims, stream); + }, + "a"_a, + "ord"_a, + "axis"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc()pbdoc"); } From 26bb16e768a9fb2dafb21974b630ec103f36f4b6 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Fri, 22 Dec 2023 00:33:36 +0100 Subject: [PATCH 25/37] added some tests for python bindings --- python/tests/test_linalg.py | 41 +++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 python/tests/test_linalg.py diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py new file mode 100644 index 000000000..26e9587c5 --- /dev/null +++ b/python/tests/test_linalg.py @@ -0,0 +1,41 @@ +# Copyright © 2023 Apple Inc. 
+
+import itertools
+import unittest
+
+import mlx.core as mx
+import mlx_tests
+import numpy as np
+
+
+class TestLinalg(mlx_tests.MLXTestCase):
+    def test_norm(self):
+        def check_mx_np(a_mx, a_np):
+            self.assertTrue(np.allclose(a_np, a_mx, atol=1e-5, rtol=1e-6))
+
+        x_mx = mx.arange(18).reshape((2, 3, 3))
+        x_np = np.arange(18).reshape((2, 3, 3))
+
+        for num_axes in range(1, 3):
+            for axis in itertools.combinations(range(3), num_axes):
+                if num_axes == 1:
+                    ords = [None, 0.5, 0, 1, 2, 3, -1, 1]
+                else:
+                    ords = [None, "fro", -1, 1]
+                for o in ords:
+                    for keepdims in [True, False]:
+                        if o:
+                            out_np = np.linalg.norm(
+                                x_np, ord=o, axis=axis, keepdims=keepdims
+                            )
+                            out_mx = mx.linalg.norm(
+                                x_mx, ord=o, axis=axis, keepdims=keepdims
+                            )
+                        else:
+                            out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
+                            out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
+                        assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+
+
+if __name__ == "__main__":
+    unittest.main()

From 49c48de53be6b72eb4dfebebb76e1404fa30f625 Mon Sep 17 00:00:00 2001
From: Gabrijel Boduljak
Date: Fri, 22 Dec 2023 01:19:57 +0100
Subject: [PATCH 26/37] handling inf, -inf as numpy does, more extensive tests
 of compatibility with numpy

---
 python/src/linalg.cpp       | 19 ++++++++++++
 python/tests/test_linalg.py | 58 ++++++++++++++++++++++---------------
 2 files changed, 54 insertions(+), 23 deletions(-)

diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index c2728c738..00cb81dc4 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -1,6 +1,7 @@
 // Copyright © 2023 Apple Inc.

+#include <cmath>
 #include
 #include
 #include
@@ -68,6 +69,12 @@ void init_linalg(py::module_& parent_module) {
          const double ord,
          const bool keepdims,
          const StreamOrDevice stream) {
+        if (std::isinf((float)ord) || std::isinf(ord))
+          if (ord > 0)
+            return norm(a, "inf", {}, keepdims, stream);
+          else
+            return norm(a, "-inf", {}, keepdims, stream);
+
         return norm(a, ord, {}, keepdims, stream);
       },
       "a"_a,
@@ -82,6 +89,12 @@ void init_linalg(py::module_& parent_module) {
          const int axis,
          const bool keepdims,
          const StreamOrDevice stream) {
+        if (std::isinf((float)ord) || std::isinf(ord))
+          if (ord > 0)
+            return norm(a, "inf", {axis}, keepdims, stream);
+          else
+            return norm(a, "-inf", {axis}, keepdims, stream);
+
         return norm(a, ord, {axis}, keepdims, stream);
       },
       "a"_a,
@@ -97,6 +110,12 @@ void init_linalg(py::module_& parent_module) {
          const std::vector<int>& axis,
          const bool keepdims,
          const StreamOrDevice stream) {
+        if (std::isinf((float)ord) || std::isinf(ord))
+          if (ord > 0)
+            return norm(a, "inf", axis, keepdims, stream);
+          else
+            return norm(a, "-inf", axis, keepdims, stream);
+
         return norm(a, ord, axis, keepdims, stream);
       },
       "a"_a,
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index 26e9587c5..1969e1028 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 import itertools
+import math
 import unittest
 
 import mlx.core as mx
@@ -10,31 +11,42 @@ import numpy as np
 
 class TestLinalg(mlx_tests.MLXTestCase):
     def test_norm(self):
-        def check_mx_np(a_mx, a_np):
-            self.assertTrue(np.allclose(a_np, a_mx, atol=1e-5, rtol=1e-6))
+        vector_ords = [None, 0.5, 0, 1, 2, 3, -1, 1, float("inf"), -float("inf")]
+        matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
 
-        x_mx = mx.arange(18).reshape((2, 3, 3))
-        x_np = np.arange(18).reshape((2, 3, 3))
+        for shape in [(3,), (2, 3), (2, 3, 3)]:
+            x_mx = mx.arange(math.prod(shape)).reshape(shape)
+            x_np = np.arange(math.prod(shape)).reshape(shape)
+            # Test when at least one axis is provided
+            for num_axes in range(1, len(shape)):
+                for axis in itertools.combinations(range(len(shape)), num_axes):
+                    if num_axes == 1:
+                        ords = vector_ords
+                    else:
+                        ords = matrix_ords
+                    for o in ords:
+                        for keepdims in [True, False]:
+                            if o:
+                                out_np = np.linalg.norm(
+                                    x_np, ord=o, axis=axis, keepdims=keepdims
+                                )
+                                out_mx = mx.linalg.norm(
+                                    x_mx, ord=o, axis=axis, keepdims=keepdims
+                                )
+                            else:
+                                out_np = np.linalg.norm(
+                                    x_np, axis=axis, keepdims=keepdims
+                                )
+                                out_mx = mx.linalg.norm(
+                                    x_mx, axis=axis, keepdims=keepdims
+                                )
+                            assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
 
-        for num_axes in range(1, 3):
-            for axis in itertools.combinations(range(3), num_axes):
-                if num_axes == 1:
-                    ords = [None, 0.5, 0, 1, 2, 3, -1, 1]
-                else:
-                    ords = [None, "fro", -1, 1]
-                for o in ords:
-                    for keepdims in [True, False]:
-                        if o:
-                            out_np = np.linalg.norm(
-                                x_np, ord=o, axis=axis, keepdims=keepdims
-                            )
-                            out_mx = mx.linalg.norm(
-                                x_mx, ord=o, axis=axis, keepdims=keepdims
-                            )
-                        else:
-                            out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
-                            out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
-                        assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+            # Test when no axes and no ords are provided
+            for keepdims in [True, False]:
+                out_np = np.linalg.norm(x_np, keepdims=keepdims)
+                out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
+                assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
 
 
 if __name__ == "__main__":
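A minimal usage sketch of the inf/-inf routing this patch introduces
(illustrative, not part of the patch series; keyword form is used so the call
resolves to the ord overloads rather than the axis overloads):

    import mlx.core as mx

    x = mx.array([-4.0, -3.0, 2.0])
    # float("inf") is rewritten to the "inf" string overload: max(abs(x))
    assert mx.linalg.norm(x, ord=float("inf")).item() == 4.0
    # -float("inf") is rewritten to "-inf": min(abs(x))
    assert mx.linalg.norm(x, ord=-float("inf")).item() == 2.0
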
From 2831c77b33c883372d7d61e7b977619bd74260a0 Mon Sep 17 00:00:00 2001
From: Gabrijel Boduljak
Date: Fri, 22 Dec 2023 02:39:47 +0100
Subject: [PATCH 27/37] added better docs and examples

---
 docs/src/python/linalg.rst |   11 +
 python/src/linalg.cpp      | 1253 +++++++++++++++++++++++++++++++++++-
 2 files changed, 1249 insertions(+), 15 deletions(-)
 create mode 100644 docs/src/python/linalg.rst

diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst
new file mode 100644
index 000000000..6c9daa100
--- /dev/null
+++ b/docs/src/python/linalg.rst
@@ -0,0 +1,11 @@
+.. _linalg:
+
+Linear Algebra
+==============
+
+.. currentmodule:: mlx.core.linalg
+
+.. autosummary::
+   :toctree: _autosummary
+
+   norm
\ No newline at end of file
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 00cb81dc4..902b196a8 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -35,8 +35,263 @@ void init_linalg(py::module_& parent_module) {
       "a"_a,
       "keepdims"_a = false,
       "stream"_a = none,
-      R"pbdoc()pbdoc");
+      R"pbdoc(
+        Matrix or vector norm.
+        This function is able to return matrix or vector norms,
+        depending on the value of the ``ord`` parameter.
+
+        Parameters
+        ----------
+        a : array_like
+          Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord`
+          is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned.
+ ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. + axis : {None, int, 2-tuple of ints}, optional. + If `axis` is an integer, it specifies the axis of `a` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None then either a vector norm (when `a` + is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default + is None. + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `a`. + + Returns + ------- + n : array + Norm of the matrix or vector(s). + + Notes + ----- + For values of ``ord < 1``, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. + + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Nuclear norm and norms based on singular values are not yet implemented. + + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + The nuclear norm is the sum of the singular values. + + Both the Frobenius and nuclear norm orders are only defined for + matrices and raise a ValueError when ``a.ndim != 2``. + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + + Examples + -------- + >>> import mlx.core as mx + >>> from mlx.core import linalg as LA + >>> a = mx.arange(9) - 4 + >>> a + array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) + >>> b = a.reshape((3,3)) + >>> b + array([[-4, -3, -2], + [-1, 0, 1], + [ 2, 3, 4]], dtype=int32) + >>> LA.norm(a) + array(7.74597, dtype=float32) + >>> LA.norm(b) + array(7.74597, dtype=float32) + >>> LA.norm(b, 'fro') + array(7.74597, dtype=float32) + >>> LA.norm(a, float("inf")) + array(4, dtype=int32) + >>> LA.norm(b, float("inf")) + array(9, dtype=int32) + >>> LA.norm(a, -float("inf")) + array(0, dtype=int32) + >>> LA.norm(b, -float("inf")) + array(2, dtype=int32) + >>> LA.norm(a, 1) + array(20, dtype=int32) + >>> LA.norm(b, 1) + array(7, dtype=int32) + >>> LA.norm(a, -1) + array(0, dtype=float32) + >>> LA.norm(b, -1) + array(6, dtype=int32) + >>> LA.norm(a, 2) + array(7.74597, dtype=float32) + >>> LA.norm(a, 3) + array(5.84804, dtype=float32) + >>> LA.norm(a, -3) + array(0, dtype=float32) + >>> c = mx.array([[ 1, 2, 3], + ... 
[-1, 1, 4]]) + >>> LA.norm(c, axis=0) + array([1.41421, 2.23607, 5], dtype=float32) + >>> LA.norm(c, axis=1) + array([3.74166, 4.24264], dtype=float32) + >>> LA.norm(c, ord=1, axis=1) + array([6, 6], dtype=int32) + >>> m = mx.arange(8).reshape(2,2,2) + array([3.74166, 11.225], dtype=float32) + >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (array(3.74166, dtype=float32), array(11.225, dtype=float32)) + )pbdoc"); + + m.def( + "norm", + [](const array& a, + const int ord, + const bool keepdims, + const StreamOrDevice stream) { + return norm(a, ord, {}, keepdims, stream); + }, + "a"_a, + "ord"_a, + "keepdims"_a = false, + "stream"_a = none, + R"pbdoc( + Matrix or vector norm. + + This function is able to return matrix or vector norms, + depending on the value of the ``ord`` parameter. + + Parameters + ---------- + a : array_like + Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` + is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. + axis : {None, int, 2-tuple of ints}, optional. + If `axis` is an integer, it specifies the axis of `a` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None then either a vector norm (when `a` + is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default + is None. + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `a`. + + Returns + ------- + n : array + Norm of the matrix or vector(s). + + Notes + ----- + For values of ``ord < 1``, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. + + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Nuclear norm and norms based on singular values are not yet implemented. + + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + The nuclear norm is the sum of the singular values. + + Both the Frobenius and nuclear norm orders are only defined for + matrices and raise a ValueError when ``a.ndim != 2``. + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 
15 + + Examples + -------- + >>> import mlx.core as mx + >>> from mlx.core import linalg as LA + >>> a = mx.arange(9) - 4 + >>> a + array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) + >>> b = a.reshape((3,3)) + >>> b + array([[-4, -3, -2], + [-1, 0, 1], + [ 2, 3, 4]], dtype=int32) + >>> LA.norm(a) + array(7.74597, dtype=float32) + >>> LA.norm(b) + array(7.74597, dtype=float32) + >>> LA.norm(b, 'fro') + array(7.74597, dtype=float32) + >>> LA.norm(a, float("inf")) + array(4, dtype=int32) + >>> LA.norm(b, float("inf")) + array(9, dtype=int32) + >>> LA.norm(a, -float("inf")) + array(0, dtype=int32) + >>> LA.norm(b, -float("inf")) + array(2, dtype=int32) + >>> LA.norm(a, 1) + array(20, dtype=int32) + >>> LA.norm(b, 1) + array(7, dtype=int32) + >>> LA.norm(a, -1) + array(0, dtype=float32) + >>> LA.norm(b, -1) + array(6, dtype=int32) + >>> LA.norm(a, 2) + array(7.74597, dtype=float32) + >>> LA.norm(a, 3) + array(5.84804, dtype=float32) + >>> LA.norm(a, -3) + array(0, dtype=float32) + >>> c = mx.array([[ 1, 2, 3], + ... [-1, 1, 4]]) + >>> LA.norm(c, axis=0) + array([1.41421, 2.23607, 5], dtype=float32) + >>> LA.norm(c, axis=1) + array([3.74166, 4.24264], dtype=float32) + >>> LA.norm(c, ord=1, axis=1) + array([6, 6], dtype=int32) + >>> m = mx.arange(8).reshape(2,2,2) + array([3.74166, 11.225], dtype=float32) + >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (array(3.74166, dtype=float32), array(11.225, dtype=float32)) + )pbdoc"); m.def( "norm", [](const array& a, @@ -49,7 +304,128 @@ void init_linalg(py::module_& parent_module) { "axis"_a, "keepdims"_a = false, "stream"_a = none, - R"pbdoc()pbdoc"); + R"pbdoc( + Matrix or vector norm. + + This function is able to return matrix or vector norms, + depending on the value of the ``ord`` parameter. + + Parameters + ---------- + a : array_like + Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` + is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. + axis : {None, int, 2-tuple of ints}, optional. + If `axis` is an integer, it specifies the axis of `a` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None then either a vector norm (when `a` + is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default + is None. + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `a`. + + Returns + ------- + n : array + Norm of the matrix or vector(s). + + Notes + ----- + For values of ``ord < 1``, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. + + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 2-norm (largest sing. 
value) as below + -2 smallest singular value as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Nuclear norm and norms based on singular values are not yet implemented. + + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + The nuclear norm is the sum of the singular values. + + Both the Frobenius and nuclear norm orders are only defined for + matrices and raise a ValueError when ``a.ndim != 2``. + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + + Examples + -------- + >>> import mlx.core as mx + >>> from mlx.core import linalg as LA + >>> a = mx.arange(9) - 4 + >>> a + array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) + >>> b = a.reshape((3,3)) + >>> b + array([[-4, -3, -2], + [-1, 0, 1], + [ 2, 3, 4]], dtype=int32) + >>> LA.norm(a) + array(7.74597, dtype=float32) + >>> LA.norm(b) + array(7.74597, dtype=float32) + >>> LA.norm(b, 'fro') + array(7.74597, dtype=float32) + >>> LA.norm(a, float("inf")) + array(4, dtype=int32) + >>> LA.norm(b, float("inf")) + array(9, dtype=int32) + >>> LA.norm(a, -float("inf")) + array(0, dtype=int32) + >>> LA.norm(b, -float("inf")) + array(2, dtype=int32) + >>> LA.norm(a, 1) + array(20, dtype=int32) + >>> LA.norm(b, 1) + array(7, dtype=int32) + >>> LA.norm(a, -1) + array(0, dtype=float32) + >>> LA.norm(b, -1) + array(6, dtype=int32) + >>> LA.norm(a, 2) + array(7.74597, dtype=float32) + >>> LA.norm(a, 3) + array(5.84804, dtype=float32) + >>> LA.norm(a, -3) + array(0, dtype=float32) + >>> c = mx.array([[ 1, 2, 3], + ... [-1, 1, 4]]) + >>> LA.norm(c, axis=0) + array([1.41421, 2.23607, 5], dtype=float32) + >>> LA.norm(c, axis=1) + array([3.74166, 4.24264], dtype=float32) + >>> LA.norm(c, ord=1, axis=1) + array([6, 6], dtype=int32) + >>> m = mx.arange(8).reshape(2,2,2) + array([3.74166, 11.225], dtype=float32) + >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (array(3.74166, dtype=float32), array(11.225, dtype=float32)) + )pbdoc"); m.def( "norm", [](const array& a, @@ -62,26 +438,268 @@ void init_linalg(py::module_& parent_module) { "axis"_a, "keepdims"_a = false, "stream"_a = none, - R"pbdoc()pbdoc"); + R"pbdoc( + Matrix or vector norm. + + This function is able to return matrix or vector norms, + depending on the value of the ``ord`` parameter. + + Parameters + ---------- + a : array_like + Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` + is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. + axis : {None, int, 2-tuple of ints}, optional. + If `axis` is an integer, it specifies the axis of `a` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None then either a vector norm (when `a` + is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default + is None. + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `a`. + + Returns + ------- + n : array + Norm of the matrix or vector(s). 
+ + Notes + ----- + For values of ``ord < 1``, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. + + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Nuclear norm and norms based on singular values are not yet implemented. + + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + The nuclear norm is the sum of the singular values. + + Both the Frobenius and nuclear norm orders are only defined for + matrices and raise a ValueError when ``a.ndim != 2``. + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + + Examples + -------- + >>> import mlx.core as mx + >>> from mlx.core import linalg as LA + >>> a = mx.arange(9) - 4 + >>> a + array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) + >>> b = a.reshape((3,3)) + >>> b + array([[-4, -3, -2], + [-1, 0, 1], + [ 2, 3, 4]], dtype=int32) + >>> LA.norm(a) + array(7.74597, dtype=float32) + >>> LA.norm(b) + array(7.74597, dtype=float32) + >>> LA.norm(b, 'fro') + array(7.74597, dtype=float32) + >>> LA.norm(a, float("inf")) + array(4, dtype=int32) + >>> LA.norm(b, float("inf")) + array(9, dtype=int32) + >>> LA.norm(a, -float("inf")) + array(0, dtype=int32) + >>> LA.norm(b, -float("inf")) + array(2, dtype=int32) + >>> LA.norm(a, 1) + array(20, dtype=int32) + >>> LA.norm(b, 1) + array(7, dtype=int32) + >>> LA.norm(a, -1) + array(0, dtype=float32) + >>> LA.norm(b, -1) + array(6, dtype=int32) + >>> LA.norm(a, 2) + array(7.74597, dtype=float32) + >>> LA.norm(a, 3) + array(5.84804, dtype=float32) + >>> LA.norm(a, -3) + array(0, dtype=float32) + >>> c = mx.array([[ 1, 2, 3], + ... [-1, 1, 4]]) + >>> LA.norm(c, axis=0) + array([1.41421, 2.23607, 5], dtype=float32) + >>> LA.norm(c, axis=1) + array([3.74166, 4.24264], dtype=float32) + >>> LA.norm(c, ord=1, axis=1) + array([6, 6], dtype=int32) + >>> m = mx.arange(8).reshape(2,2,2) + array([3.74166, 11.225], dtype=float32) + >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (array(3.74166, dtype=float32), array(11.225, dtype=float32)) + )pbdoc"); m.def( "norm", [](const array& a, const double ord, const bool keepdims, const StreamOrDevice stream) { - if (std::isinf((float)ord) || std::isinf(ord)) + if (std::isinf((float)ord) || std::isinf(ord)) { if (ord > 0) return norm(a, "inf", {}, keepdims, stream); else return norm(a, "-inf", {}, keepdims, stream); - + } return norm(a, ord, {}, keepdims, stream); }, "a"_a, "ord"_a, "keepdims"_a = false, "stream"_a = none, - R"pbdoc()pbdoc"); + R"pbdoc( + Matrix or vector norm. + + This function is able to return matrix or vector norms, + depending on the value of the ``ord`` parameter. + + Parameters + ---------- + a : array_like + Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` + is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. 
+ ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. + axis : {None, int, 2-tuple of ints}, optional. + If `axis` is an integer, it specifies the axis of `a` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None then either a vector norm (when `a` + is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default + is None. + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `a`. + + Returns + ------- + n : array + Norm of the matrix or vector(s). + + Notes + ----- + For values of ``ord < 1``, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. + + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Nuclear norm and norms based on singular values are not yet implemented. + + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + The nuclear norm is the sum of the singular values. + + Both the Frobenius and nuclear norm orders are only defined for + matrices and raise a ValueError when ``a.ndim != 2``. + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + + Examples + -------- + >>> import mlx.core as mx + >>> from mlx.core import linalg as LA + >>> a = mx.arange(9) - 4 + >>> a + array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) + >>> b = a.reshape((3,3)) + >>> b + array([[-4, -3, -2], + [-1, 0, 1], + [ 2, 3, 4]], dtype=int32) + >>> LA.norm(a) + array(7.74597, dtype=float32) + >>> LA.norm(b) + array(7.74597, dtype=float32) + >>> LA.norm(b, 'fro') + array(7.74597, dtype=float32) + >>> LA.norm(a, float("inf")) + array(4, dtype=int32) + >>> LA.norm(b, float("inf")) + array(9, dtype=int32) + >>> LA.norm(a, -float("inf")) + array(0, dtype=int32) + >>> LA.norm(b, -float("inf")) + array(2, dtype=int32) + >>> LA.norm(a, 1) + array(20, dtype=int32) + >>> LA.norm(b, 1) + array(7, dtype=int32) + >>> LA.norm(a, -1) + array(0, dtype=float32) + >>> LA.norm(b, -1) + array(6, dtype=int32) + >>> LA.norm(a, 2) + array(7.74597, dtype=float32) + >>> LA.norm(a, 3) + array(5.84804, dtype=float32) + >>> LA.norm(a, -3) + array(0, dtype=float32) + >>> c = mx.array([[ 1, 2, 3], + ... 
[-1, 1, 4]]) + >>> LA.norm(c, axis=0) + array([1.41421, 2.23607, 5], dtype=float32) + >>> LA.norm(c, axis=1) + array([3.74166, 4.24264], dtype=float32) + >>> LA.norm(c, ord=1, axis=1) + array([6, 6], dtype=int32) + >>> m = mx.arange(8).reshape(2,2,2) + array([3.74166, 11.225], dtype=float32) + >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (array(3.74166, dtype=float32), array(11.225, dtype=float32)) + )pbdoc"); m.def( "norm", [](const array& a, @@ -89,12 +707,12 @@ void init_linalg(py::module_& parent_module) { const int axis, const bool keepdims, const StreamOrDevice stream) { - if (std::isinf((float)ord) || std::isinf(ord)) + if (std::isinf((float)ord) || std::isinf(ord)) { if (ord > 0) return norm(a, "inf", {axis}, keepdims, stream); else return norm(a, "-inf", {axis}, keepdims, stream); - + } return norm(a, ord, {axis}, keepdims, stream); }, "a"_a, @@ -102,7 +720,128 @@ void init_linalg(py::module_& parent_module) { "axis"_a, "keepdims"_a = false, "stream"_a = none, - R"pbdoc()pbdoc"); + R"pbdoc( + Matrix or vector norm. + + This function is able to return matrix or vector norms, + depending on the value of the ``ord`` parameter. + + Parameters + ---------- + a : array_like + Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` + is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. + axis : {None, int, 2-tuple of ints}, optional. + If `axis` is an integer, it specifies the axis of `a` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None then either a vector norm (when `a` + is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default + is None. + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `a`. + + Returns + ------- + n : array + Norm of the matrix or vector(s). + + Notes + ----- + For values of ``ord < 1``, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. + + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Nuclear norm and norms based on singular values are not yet implemented. + + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + The nuclear norm is the sum of the singular values. + + Both the Frobenius and nuclear norm orders are only defined for + matrices and raise a ValueError when ``a.ndim != 2``. + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 
15 + + Examples + -------- + >>> import mlx.core as mx + >>> from mlx.core import linalg as LA + >>> a = mx.arange(9) - 4 + >>> a + array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) + >>> b = a.reshape((3,3)) + >>> b + array([[-4, -3, -2], + [-1, 0, 1], + [ 2, 3, 4]], dtype=int32) + >>> LA.norm(a) + array(7.74597, dtype=float32) + >>> LA.norm(b) + array(7.74597, dtype=float32) + >>> LA.norm(b, 'fro') + array(7.74597, dtype=float32) + >>> LA.norm(a, float("inf")) + array(4, dtype=int32) + >>> LA.norm(b, float("inf")) + array(9, dtype=int32) + >>> LA.norm(a, -float("inf")) + array(0, dtype=int32) + >>> LA.norm(b, -float("inf")) + array(2, dtype=int32) + >>> LA.norm(a, 1) + array(20, dtype=int32) + >>> LA.norm(b, 1) + array(7, dtype=int32) + >>> LA.norm(a, -1) + array(0, dtype=float32) + >>> LA.norm(b, -1) + array(6, dtype=int32) + >>> LA.norm(a, 2) + array(7.74597, dtype=float32) + >>> LA.norm(a, 3) + array(5.84804, dtype=float32) + >>> LA.norm(a, -3) + array(0, dtype=float32) + >>> c = mx.array([[ 1, 2, 3], + ... [-1, 1, 4]]) + >>> LA.norm(c, axis=0) + array([1.41421, 2.23607, 5], dtype=float32) + >>> LA.norm(c, axis=1) + array([3.74166, 4.24264], dtype=float32) + >>> LA.norm(c, ord=1, axis=1) + array([6, 6], dtype=int32) + >>> m = mx.arange(8).reshape(2,2,2) + array([3.74166, 11.225], dtype=float32) + >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (array(3.74166, dtype=float32), array(11.225, dtype=float32)) + )pbdoc"); m.def( "norm", [](const array& a, @@ -110,12 +849,12 @@ void init_linalg(py::module_& parent_module) { const std::vector& axis, const bool keepdims, const StreamOrDevice stream) { - if (std::isinf((float)ord) || std::isinf(ord)) + if (std::isinf((float)ord) || std::isinf(ord)) { if (ord > 0) return norm(a, "inf", axis, keepdims, stream); else return norm(a, "-inf", axis, keepdims, stream); - + } return norm(a, ord, axis, keepdims, stream); }, "a"_a, @@ -123,7 +862,128 @@ void init_linalg(py::module_& parent_module) { "axis"_a, "keepdims"_a = false, "stream"_a = none, - R"pbdoc()pbdoc"); + R"pbdoc( + Matrix or vector norm. + + This function is able to return matrix or vector norms, + depending on the value of the ``ord`` parameter. + + Parameters + ---------- + a : array_like + Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` + is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. + axis : {None, int, 2-tuple of ints}, optional. + If `axis` is an integer, it specifies the axis of `a` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None then either a vector norm (when `a` + is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default + is None. + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `a`. + + Returns + ------- + n : array + Norm of the matrix or vector(s). + + Notes + ----- + For values of ``ord < 1``, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. 
+ + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Nuclear norm and norms based on singular values are not yet implemented. + + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + The nuclear norm is the sum of the singular values. + + Both the Frobenius and nuclear norm orders are only defined for + matrices and raise a ValueError when ``a.ndim != 2``. + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + + Examples + -------- + >>> import mlx.core as mx + >>> from mlx.core import linalg as LA + >>> a = mx.arange(9) - 4 + >>> a + array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) + >>> b = a.reshape((3,3)) + >>> b + array([[-4, -3, -2], + [-1, 0, 1], + [ 2, 3, 4]], dtype=int32) + >>> LA.norm(a) + array(7.74597, dtype=float32) + >>> LA.norm(b) + array(7.74597, dtype=float32) + >>> LA.norm(b, 'fro') + array(7.74597, dtype=float32) + >>> LA.norm(a, float("inf")) + array(4, dtype=int32) + >>> LA.norm(b, float("inf")) + array(9, dtype=int32) + >>> LA.norm(a, -float("inf")) + array(0, dtype=int32) + >>> LA.norm(b, -float("inf")) + array(2, dtype=int32) + >>> LA.norm(a, 1) + array(20, dtype=int32) + >>> LA.norm(b, 1) + array(7, dtype=int32) + >>> LA.norm(a, -1) + array(0, dtype=float32) + >>> LA.norm(b, -1) + array(6, dtype=int32) + >>> LA.norm(a, 2) + array(7.74597, dtype=float32) + >>> LA.norm(a, 3) + array(5.84804, dtype=float32) + >>> LA.norm(a, -3) + array(0, dtype=float32) + >>> c = mx.array([[ 1, 2, 3], + ... [-1, 1, 4]]) + >>> LA.norm(c, axis=0) + array([1.41421, 2.23607, 5], dtype=float32) + >>> LA.norm(c, axis=1) + array([3.74166, 4.24264], dtype=float32) + >>> LA.norm(c, ord=1, axis=1) + array([6, 6], dtype=int32) + >>> m = mx.arange(8).reshape(2,2,2) + array([3.74166, 11.225], dtype=float32) + >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (array(3.74166, dtype=float32), array(11.225, dtype=float32)) + )pbdoc"); m.def( "norm", [](const array& a, @@ -136,7 +996,128 @@ void init_linalg(py::module_& parent_module) { "ord"_a, "keepdims"_a = false, "stream"_a = none, - R"pbdoc()pbdoc"); + R"pbdoc( + Matrix or vector norm. + + This function is able to return matrix or vector norms, + depending on the value of the ``ord`` parameter. + + Parameters + ---------- + a : array_like + Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` + is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. + axis : {None, int, 2-tuple of ints}, optional. + If `axis` is an integer, it specifies the axis of `a` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. 
If `axis` is None then either a vector norm (when `a` + is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default + is None. + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `a`. + + Returns + ------- + n : array + Norm of the matrix or vector(s). + + Notes + ----- + For values of ``ord < 1``, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. + + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Nuclear norm and norms based on singular values are not yet implemented. + + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + The nuclear norm is the sum of the singular values. + + Both the Frobenius and nuclear norm orders are only defined for + matrices and raise a ValueError when ``a.ndim != 2``. + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + + Examples + -------- + >>> import mlx.core as mx + >>> from mlx.core import linalg as LA + >>> a = mx.arange(9) - 4 + >>> a + array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) + >>> b = a.reshape((3,3)) + >>> b + array([[-4, -3, -2], + [-1, 0, 1], + [ 2, 3, 4]], dtype=int32) + >>> LA.norm(a) + array(7.74597, dtype=float32) + >>> LA.norm(b) + array(7.74597, dtype=float32) + >>> LA.norm(b, 'fro') + array(7.74597, dtype=float32) + >>> LA.norm(a, float("inf")) + array(4, dtype=int32) + >>> LA.norm(b, float("inf")) + array(9, dtype=int32) + >>> LA.norm(a, -float("inf")) + array(0, dtype=int32) + >>> LA.norm(b, -float("inf")) + array(2, dtype=int32) + >>> LA.norm(a, 1) + array(20, dtype=int32) + >>> LA.norm(b, 1) + array(7, dtype=int32) + >>> LA.norm(a, -1) + array(0, dtype=float32) + >>> LA.norm(b, -1) + array(6, dtype=int32) + >>> LA.norm(a, 2) + array(7.74597, dtype=float32) + >>> LA.norm(a, 3) + array(5.84804, dtype=float32) + >>> LA.norm(a, -3) + array(0, dtype=float32) + >>> c = mx.array([[ 1, 2, 3], + ... [-1, 1, 4]]) + >>> LA.norm(c, axis=0) + array([1.41421, 2.23607, 5], dtype=float32) + >>> LA.norm(c, axis=1) + array([3.74166, 4.24264], dtype=float32) + >>> LA.norm(c, ord=1, axis=1) + array([6, 6], dtype=int32) + >>> m = mx.arange(8).reshape(2,2,2) + array([3.74166, 11.225], dtype=float32) + >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) + (array(3.74166, dtype=float32), array(11.225, dtype=float32)) + )pbdoc"); m.def( "norm", [](const array& a, @@ -151,7 +1132,128 @@ void init_linalg(py::module_& parent_module) { "axis"_a, "keepdims"_a = false, "stream"_a = none, - R"pbdoc()pbdoc"); + R"pbdoc( + Matrix or vector norm. + + This function is able to return matrix or vector norms, + depending on the value of the ``ord`` parameter. 
+ + Parameters + ---------- + a : array_like + Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` + is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. + ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional + Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. + axis : {None, int, 2-tuple of ints}, optional. + If `axis` is an integer, it specifies the axis of `a` along which to + compute the vector norms. If `axis` is a 2-tuple, it specifies the + axes that hold 2-D matrices, and the matrix norms of these matrices + are computed. If `axis` is None then either a vector norm (when `a` + is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default + is None. + keepdims : bool, optional + If this is set to True, the axes which are normed over are left in the + result as dimensions with size one. With this option the result will + broadcast correctly against the original `a`. + + Returns + ------- + n : array + Norm of the matrix or vector(s). + + Notes + ----- + For values of ``ord < 1``, the result is, strictly speaking, not a + mathematical 'norm', but it may still be useful for various numerical + purposes. + + The following norms can be calculated: + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm -- + inf max(sum(abs(x), axis=1)) max(abs(x)) + -inf min(sum(abs(x), axis=1)) min(abs(x)) + 0 -- sum(x != 0) + 1 max(sum(abs(x), axis=0)) as below + -1 min(sum(abs(x), axis=0)) as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other -- sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Nuclear norm and norms based on singular values are not yet implemented. + + The Frobenius norm is given by [1]_: + + :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` + + The nuclear norm is the sum of the singular values. + + Both the Frobenius and nuclear norm orders are only defined for + matrices and raise a ValueError when ``a.ndim != 2``. + + References + ---------- + .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, + Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 + + Examples + -------- + >>> import mlx.core as mx + >>> from mlx.core import linalg as LA + >>> a = mx.arange(9) - 4 + >>> a + array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) + >>> b = a.reshape((3,3)) + >>> b + array([[-4, -3, -2], + [-1, 0, 1], + [ 2, 3, 4]], dtype=int32) + >>> LA.norm(a) + array(7.74597, dtype=float32) + >>> LA.norm(b) + array(7.74597, dtype=float32) + >>> LA.norm(b, 'fro') + array(7.74597, dtype=float32) + >>> LA.norm(a, float("inf")) + array(4, dtype=int32) + >>> LA.norm(b, float("inf")) + array(9, dtype=int32) + >>> LA.norm(a, -float("inf")) + array(0, dtype=int32) + >>> LA.norm(b, -float("inf")) + array(2, dtype=int32) + >>> LA.norm(a, 1) + array(20, dtype=int32) + >>> LA.norm(b, 1) + array(7, dtype=int32) + >>> LA.norm(a, -1) + array(0, dtype=float32) + >>> LA.norm(b, -1) + array(6, dtype=int32) + >>> LA.norm(a, 2) + array(7.74597, dtype=float32) + >>> LA.norm(a, 3) + array(5.84804, dtype=float32) + >>> LA.norm(a, -3) + array(0, dtype=float32) + >>> c = mx.array([[ 1, 2, 3], + ... 
[-1, 1, 4]])
+        >>> LA.norm(c, axis=0)
+        array([1.41421, 2.23607, 5], dtype=float32)
+        >>> LA.norm(c, axis=1)
+        array([3.74166, 4.24264], dtype=float32)
+        >>> LA.norm(c, ord=1, axis=1)
+        array([6, 6], dtype=int32)
+        >>> m = mx.arange(8).reshape(2,2,2)
+        >>> LA.norm(m, axis=(1,2))
+        array([3.74166, 11.225], dtype=float32)
+        >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
+        (array(3.74166, dtype=float32), array(11.225, dtype=float32))
+      )pbdoc");
   m.def(
       "norm",
       [](const array& a,
          const std::string& ord,
          const std::vector<int>& axis,
          const bool keepdims,
          const StreamOrDevice stream) {
        return norm(a, ord, axis, keepdims, stream);
       },
       "a"_a,
       "ord"_a,
       "axis"_a,
       "keepdims"_a = false,
       "stream"_a = none,
-      R"pbdoc()pbdoc");
+      R"pbdoc(
+        Matrix or vector norm.
+
+        This function is able to return matrix or vector norms,
+        depending on the value of the ``ord`` parameter.
+
+        Parameters
+        ----------
+        a : array_like
+          Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord`
+          is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned.
+        ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
+          Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None.
+        axis : {None, int, 2-tuple of ints}, optional.
+          If `axis` is an integer, it specifies the axis of `a` along which to
+          compute the vector norms. If `axis` is a 2-tuple, it specifies the
+          axes that hold 2-D matrices, and the matrix norms of these matrices
+          are computed. If `axis` is None then either a vector norm (when `a`
+          is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default
+          is None.
+        keepdims : bool, optional
+          If this is set to True, the axes which are normed over are left in the
+          result as dimensions with size one. With this option the result will
+          broadcast correctly against the original `a`.
+
+        Returns
+        -------
+        n : array
+          Norm of the matrix or vector(s).
+
+        Notes
+        -----
+        For values of ``ord < 1``, the result is, strictly speaking, not a
+        mathematical 'norm', but it may still be useful for various numerical
+        purposes.
+
+        The following norms can be calculated:
+
+        =====  ============================  ==========================
+        ord    norm for matrices             norm for vectors
+        =====  ============================  ==========================
+        None   Frobenius norm                2-norm
+        'fro'  Frobenius norm                --
+        inf    max(sum(abs(x), axis=1))      max(abs(x))
+        -inf   min(sum(abs(x), axis=1))      min(abs(x))
+        0      --                            sum(x != 0)
+        1      max(sum(abs(x), axis=0))      as below
+        -1     min(sum(abs(x), axis=0))      as below
+        2      2-norm (largest sing. value)  as below
+        -2     smallest singular value       as below
+        other  --                            sum(abs(x)**ord)**(1./ord)
+        =====  ============================  ==========================
+
+        Nuclear norm and norms based on singular values are not yet implemented.
+
+        The Frobenius norm is given by [1]_:
+
+        :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
+
+        The nuclear norm is the sum of the singular values.
+
+        Both the Frobenius and nuclear norm orders are only defined for
+        matrices and raise a ValueError when ``a.ndim != 2``.
+
+        References
+        ----------
+        .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
+          Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
+
+        Examples
+        --------
+        >>> import mlx.core as mx
+        >>> from mlx.core import linalg as LA
+        >>> a = mx.arange(9) - 4
+        >>> a
+        array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
+        >>> b = a.reshape((3,3))
+        >>> b
+        array([[-4, -3, -2],
+              [-1,  0,  1],
+              [ 2,  3,  4]], dtype=int32)
+        >>> LA.norm(a)
+        array(7.74597, dtype=float32)
+        >>> LA.norm(b)
+        array(7.74597, dtype=float32)
+        >>> LA.norm(b, 'fro')
+        array(7.74597, dtype=float32)
+        >>> LA.norm(a, float("inf"))
+        array(4, dtype=int32)
+        >>> LA.norm(b, float("inf"))
+        array(9, dtype=int32)
+        >>> LA.norm(a, -float("inf"))
+        array(0, dtype=int32)
+        >>> LA.norm(b, -float("inf"))
+        array(2, dtype=int32)
+        >>> LA.norm(a, 1)
+        array(20, dtype=int32)
+        >>> LA.norm(b, 1)
+        array(7, dtype=int32)
+        >>> LA.norm(a, -1)
+        array(0, dtype=float32)
+        >>> LA.norm(b, -1)
+        array(6, dtype=int32)
+        >>> LA.norm(a, 2)
+        array(7.74597, dtype=float32)
+        >>> LA.norm(a, 3)
+        array(5.84804, dtype=float32)
+        >>> LA.norm(a, -3)
+        array(0, dtype=float32)
+        >>> c = mx.array([[ 1, 2, 3],
+        ...               [-1, 1, 4]])
+        >>> LA.norm(c, axis=0)
+        array([1.41421, 2.23607, 5], dtype=float32)
+        >>> LA.norm(c, axis=1)
+        array([3.74166, 4.24264], dtype=float32)
+        >>> LA.norm(c, ord=1, axis=1)
+        array([6, 6], dtype=int32)
+        >>> m = mx.arange(8).reshape(2,2,2)
+        >>> LA.norm(m, axis=(1,2))
+        array([3.74166, 11.225], dtype=float32)
+        >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
+        (array(3.74166, dtype=float32), array(11.225, dtype=float32))
+      )pbdoc");
 }
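A quick sketch before the refactor below (illustrative, not part of the patch
series): the next patch collapses the overloads above into a single binding
that dispatches on a std::variant via a small `overloaded` visitor helper in
python/src/overloaded.h (presumably the standard C++17 idiom, which fits its
8-line diffstat), so the Python surface stays the same. Assuming the
defaulted keyword form of the new binding, all of the following route through
the one lambda:

    import mlx.core as mx

    x = mx.arange(9).reshape((3, 3))
    mx.linalg.norm(x)                          # no ord, no axis: Frobenius norm
    mx.linalg.norm(x, ord=1.0, axis=0)         # double-ord branch of the visitor
    mx.linalg.norm(x, ord="fro", axis=[0, 1])  # string-ord branch
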
From 145a4d143d08458be33ca24958d41b57f219af6b Mon Sep 17 00:00:00 2001
From: Gabrijel Boduljak
Date: Fri, 22 Dec 2023 04:35:02 +0100
Subject: [PATCH 28/37] refactored mlx.linalg.norm bindings

---
 python/src/linalg.cpp   | 1270 ++-------------------------------------
 python/src/overloaded.h |    8 +
 2 files changed, 46 insertions(+), 1232 deletions(-)
 create mode 100644 python/src/overloaded.h

diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 902b196a8..d96dd8a2d 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include <variant>
 #include
 #include
@@ -15,6 +16,7 @@
 #include "mlx/utils.h"
 #include "python/src/load.h"
+#include "python/src/overloaded.h"
 #include "python/src/utils.h"
 
 namespace py = pybind11;
@@ -27,1245 +29,48 @@ void init_linalg(py::module_& parent_module) {
   auto m =
       parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
 
+  m.def(
+      "norm",
+      [](const array& a,
+         const std::variant<std::monostate, double, std::string>& ord,
+         const std::variant<std::monostate, int, std::vector<int>>& axis,
+         const bool keepdims,
+         const StreamOrDevice stream) {
+        return std::visit(
+            overloaded{
+                [&](const double p) {
+                  if (std::isinf((float)p) || std::isinf(p)) {
+                    if (p > 0) {
+                      return norm(
+                          a,
+                          "inf",
+                          get_reduce_axes(axis, a.ndim()),
+                          keepdims,
+                          stream);
+                    }
+                    return norm(
+                        a,
+                        "-inf",
+                        get_reduce_axes(axis, a.ndim()),
+                        keepdims,
+                        stream);
+                  }
+                  return norm(
+                      a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream);
+                },
+                [&](const std::string& p) {
+                  return norm(
+                      a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream);
+                },
+                [&](const std::monostate _) {
+                  return norm(
+                      a, get_reduce_axes(axis, a.ndim()), keepdims, stream);
+                }},
+            ord);
+      },
+      "a"_a,
-  m.def(
-      "norm",
-      [](const array& a, const bool keepdims, const StreamOrDevice stream) {
-        return norm(a, {}, keepdims, stream);
-      },
-      "a"_a,
-      "keepdims"_a = false,
-      "stream"_a = none,
-      R"pbdoc(
-        Matrix or vector norm.
-
-        This function is able to return matrix or vector norms,
-        depending on the value of the ``ord`` parameter.
-
-        Parameters
-        ----------
-        a : array_like
-          Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord`
-          is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned.
-        ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional
-          Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None.
-        axis : {None, int, 2-tuple of ints}, optional.
-          If `axis` is an integer, it specifies the axis of `a` along which to
-          compute the vector norms. If `axis` is a 2-tuple, it specifies the
-          axes that hold 2-D matrices, and the matrix norms of these matrices
-          are computed. If `axis` is None then either a vector norm (when `a`
-          is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default
-          is None.
- keepdims : bool, optional - If this is set to True, the axes which are normed over are left in the - result as dimensions with size one. With this option the result will - broadcast correctly against the original `a`. - - Returns - ------- - n : array - Norm of the matrix or vector(s). - - Notes - ----- - For values of ``ord < 1``, the result is, strictly speaking, not a - mathematical 'norm', but it may still be useful for various numerical - purposes. - - The following norms can be calculated: - - ===== ============================ ========================== - ord norm for matrices norm for vectors - ===== ============================ ========================== - None Frobenius norm 2-norm - 'fro' Frobenius norm -- - inf max(sum(abs(x), axis=1)) max(abs(x)) - -inf min(sum(abs(x), axis=1)) min(abs(x)) - 0 -- sum(x != 0) - 1 max(sum(abs(x), axis=0)) as below - -1 min(sum(abs(x), axis=0)) as below - 2 2-norm (largest sing. value) as below - -2 smallest singular value as below - other -- sum(abs(x)**ord)**(1./ord) - ===== ============================ ========================== - - Nuclear norm and norms based on singular values are not yet implemented. - - The Frobenius norm is given by [1]_: - - :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - - The nuclear norm is the sum of the singular values. - - Both the Frobenius and nuclear norm orders are only defined for - matrices and raise a ValueError when ``a.ndim != 2``. - - References - ---------- - .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, - Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 - - Examples - -------- - >>> import mlx.core as mx - >>> from mlx.core import linalg as LA - >>> a = mx.arange(9) - 4 - >>> a - array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) - >>> b = a.reshape((3,3)) - >>> b - array([[-4, -3, -2], - [-1, 0, 1], - [ 2, 3, 4]], dtype=int32) - >>> LA.norm(a) - array(7.74597, dtype=float32) - >>> LA.norm(b) - array(7.74597, dtype=float32) - >>> LA.norm(b, 'fro') - array(7.74597, dtype=float32) - >>> LA.norm(a, float("inf")) - array(4, dtype=int32) - >>> LA.norm(b, float("inf")) - array(9, dtype=int32) - >>> LA.norm(a, -float("inf")) - array(0, dtype=int32) - >>> LA.norm(b, -float("inf")) - array(2, dtype=int32) - >>> LA.norm(a, 1) - array(20, dtype=int32) - >>> LA.norm(b, 1) - array(7, dtype=int32) - >>> LA.norm(a, -1) - array(0, dtype=float32) - >>> LA.norm(b, -1) - array(6, dtype=int32) - >>> LA.norm(a, 2) - array(7.74597, dtype=float32) - >>> LA.norm(a, 3) - array(5.84804, dtype=float32) - >>> LA.norm(a, -3) - array(0, dtype=float32) - >>> c = mx.array([[ 1, 2, 3], - ... 
[-1, 1, 4]]) - >>> LA.norm(c, axis=0) - array([1.41421, 2.23607, 5], dtype=float32) - >>> LA.norm(c, axis=1) - array([3.74166, 4.24264], dtype=float32) - >>> LA.norm(c, ord=1, axis=1) - array([6, 6], dtype=int32) - >>> m = mx.arange(8).reshape(2,2,2) - array([3.74166, 11.225], dtype=float32) - >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) - (array(3.74166, dtype=float32), array(11.225, dtype=float32)) - )pbdoc"); - m.def( "norm", [](const array& a, - const int ord, + const std::variant& ord, + const std::variant>& axis, const bool keepdims, const StreamOrDevice stream) { - return norm(a, ord, {}, keepdims, stream); + return std::visit( + overloaded{ + [&](const double p) { + if (std::isinf((float)p) || std::isinf(p)) { + if (p > 0) { + return norm( + a, + "inf", + get_reduce_axes(axis, a.ndim()), + keepdims, + stream); + } + return norm( + a, + "-inf", + get_reduce_axes(axis, a.ndim()), + keepdims, + stream); + } + return norm( + a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream); + }, + [&](const std::string& p) { + return norm( + a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream); + }, + [&](const std::monostate _) { + return norm( + a, get_reduce_axes(axis, a.ndim()), keepdims, stream); + }}, + ord); }, "a"_a, - "ord"_a, - "keepdims"_a = false, - "stream"_a = none, - R"pbdoc( - Matrix or vector norm. - - This function is able to return matrix or vector norms, - depending on the value of the ``ord`` parameter. - - Parameters - ---------- - a : array_like - Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` - is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. - axis : {None, int, 2-tuple of ints}, optional. - If `axis` is an integer, it specifies the axis of `a` along which to - compute the vector norms. If `axis` is a 2-tuple, it specifies the - axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None then either a vector norm (when `a` - is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default - is None. - keepdims : bool, optional - If this is set to True, the axes which are normed over are left in the - result as dimensions with size one. With this option the result will - broadcast correctly against the original `a`. - - Returns - ------- - n : array - Norm of the matrix or vector(s). - - Notes - ----- - For values of ``ord < 1``, the result is, strictly speaking, not a - mathematical 'norm', but it may still be useful for various numerical - purposes. - - The following norms can be calculated: - - ===== ============================ ========================== - ord norm for matrices norm for vectors - ===== ============================ ========================== - None Frobenius norm 2-norm - 'fro' Frobenius norm -- - inf max(sum(abs(x), axis=1)) max(abs(x)) - -inf min(sum(abs(x), axis=1)) min(abs(x)) - 0 -- sum(x != 0) - 1 max(sum(abs(x), axis=0)) as below - -1 min(sum(abs(x), axis=0)) as below - 2 2-norm (largest sing. value) as below - -2 smallest singular value as below - other -- sum(abs(x)**ord)**(1./ord) - ===== ============================ ========================== - - Nuclear norm and norms based on singular values are not yet implemented. - - The Frobenius norm is given by [1]_: - - :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - - The nuclear norm is the sum of the singular values. 
- - Both the Frobenius and nuclear norm orders are only defined for - matrices and raise a ValueError when ``a.ndim != 2``. - - References - ---------- - .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, - Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 - - Examples - -------- - >>> import mlx.core as mx - >>> from mlx.core import linalg as LA - >>> a = mx.arange(9) - 4 - >>> a - array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) - >>> b = a.reshape((3,3)) - >>> b - array([[-4, -3, -2], - [-1, 0, 1], - [ 2, 3, 4]], dtype=int32) - >>> LA.norm(a) - array(7.74597, dtype=float32) - >>> LA.norm(b) - array(7.74597, dtype=float32) - >>> LA.norm(b, 'fro') - array(7.74597, dtype=float32) - >>> LA.norm(a, float("inf")) - array(4, dtype=int32) - >>> LA.norm(b, float("inf")) - array(9, dtype=int32) - >>> LA.norm(a, -float("inf")) - array(0, dtype=int32) - >>> LA.norm(b, -float("inf")) - array(2, dtype=int32) - >>> LA.norm(a, 1) - array(20, dtype=int32) - >>> LA.norm(b, 1) - array(7, dtype=int32) - >>> LA.norm(a, -1) - array(0, dtype=float32) - >>> LA.norm(b, -1) - array(6, dtype=int32) - >>> LA.norm(a, 2) - array(7.74597, dtype=float32) - >>> LA.norm(a, 3) - array(5.84804, dtype=float32) - >>> LA.norm(a, -3) - array(0, dtype=float32) - >>> c = mx.array([[ 1, 2, 3], - ... [-1, 1, 4]]) - >>> LA.norm(c, axis=0) - array([1.41421, 2.23607, 5], dtype=float32) - >>> LA.norm(c, axis=1) - array([3.74166, 4.24264], dtype=float32) - >>> LA.norm(c, ord=1, axis=1) - array([6, 6], dtype=int32) - >>> m = mx.arange(8).reshape(2,2,2) - array([3.74166, 11.225], dtype=float32) - >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) - (array(3.74166, dtype=float32), array(11.225, dtype=float32)) - )pbdoc"); - m.def( - "norm", - [](const array& a, - const int axis, - const bool keepdims, - const StreamOrDevice stream) { - return norm(a, {axis}, keepdims, stream); - }, - "a"_a, - "axis"_a, - "keepdims"_a = false, - "stream"_a = none, - R"pbdoc( - Matrix or vector norm. - - This function is able to return matrix or vector norms, - depending on the value of the ``ord`` parameter. - - Parameters - ---------- - a : array_like - Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` - is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. - axis : {None, int, 2-tuple of ints}, optional. - If `axis` is an integer, it specifies the axis of `a` along which to - compute the vector norms. If `axis` is a 2-tuple, it specifies the - axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None then either a vector norm (when `a` - is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default - is None. - keepdims : bool, optional - If this is set to True, the axes which are normed over are left in the - result as dimensions with size one. With this option the result will - broadcast correctly against the original `a`. - - Returns - ------- - n : array - Norm of the matrix or vector(s). - - Notes - ----- - For values of ``ord < 1``, the result is, strictly speaking, not a - mathematical 'norm', but it may still be useful for various numerical - purposes. 
- - The following norms can be calculated: - - ===== ============================ ========================== - ord norm for matrices norm for vectors - ===== ============================ ========================== - None Frobenius norm 2-norm - 'fro' Frobenius norm -- - inf max(sum(abs(x), axis=1)) max(abs(x)) - -inf min(sum(abs(x), axis=1)) min(abs(x)) - 0 -- sum(x != 0) - 1 max(sum(abs(x), axis=0)) as below - -1 min(sum(abs(x), axis=0)) as below - 2 2-norm (largest sing. value) as below - -2 smallest singular value as below - other -- sum(abs(x)**ord)**(1./ord) - ===== ============================ ========================== - - Nuclear norm and norms based on singular values are not yet implemented. - - The Frobenius norm is given by [1]_: - - :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - - The nuclear norm is the sum of the singular values. - - Both the Frobenius and nuclear norm orders are only defined for - matrices and raise a ValueError when ``a.ndim != 2``. - - References - ---------- - .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, - Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 - - Examples - -------- - >>> import mlx.core as mx - >>> from mlx.core import linalg as LA - >>> a = mx.arange(9) - 4 - >>> a - array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) - >>> b = a.reshape((3,3)) - >>> b - array([[-4, -3, -2], - [-1, 0, 1], - [ 2, 3, 4]], dtype=int32) - >>> LA.norm(a) - array(7.74597, dtype=float32) - >>> LA.norm(b) - array(7.74597, dtype=float32) - >>> LA.norm(b, 'fro') - array(7.74597, dtype=float32) - >>> LA.norm(a, float("inf")) - array(4, dtype=int32) - >>> LA.norm(b, float("inf")) - array(9, dtype=int32) - >>> LA.norm(a, -float("inf")) - array(0, dtype=int32) - >>> LA.norm(b, -float("inf")) - array(2, dtype=int32) - >>> LA.norm(a, 1) - array(20, dtype=int32) - >>> LA.norm(b, 1) - array(7, dtype=int32) - >>> LA.norm(a, -1) - array(0, dtype=float32) - >>> LA.norm(b, -1) - array(6, dtype=int32) - >>> LA.norm(a, 2) - array(7.74597, dtype=float32) - >>> LA.norm(a, 3) - array(5.84804, dtype=float32) - >>> LA.norm(a, -3) - array(0, dtype=float32) - >>> c = mx.array([[ 1, 2, 3], - ... [-1, 1, 4]]) - >>> LA.norm(c, axis=0) - array([1.41421, 2.23607, 5], dtype=float32) - >>> LA.norm(c, axis=1) - array([3.74166, 4.24264], dtype=float32) - >>> LA.norm(c, ord=1, axis=1) - array([6, 6], dtype=int32) - >>> m = mx.arange(8).reshape(2,2,2) - array([3.74166, 11.225], dtype=float32) - >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) - (array(3.74166, dtype=float32), array(11.225, dtype=float32)) - )pbdoc"); - m.def( - "norm", - [](const array& a, - const std::vector& axis, - const bool keepdims, - const StreamOrDevice stream) { - return norm(a, axis, keepdims, stream); - }, - "a"_a, - "axis"_a, - "keepdims"_a = false, - "stream"_a = none, - R"pbdoc( - Matrix or vector norm. - - This function is able to return matrix or vector norms, - depending on the value of the ``ord`` parameter. - - Parameters - ---------- - a : array_like - Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` - is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. - axis : {None, int, 2-tuple of ints}, optional. - If `axis` is an integer, it specifies the axis of `a` along which to - compute the vector norms. 
If `axis` is a 2-tuple, it specifies the - axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None then either a vector norm (when `a` - is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default - is None. - keepdims : bool, optional - If this is set to True, the axes which are normed over are left in the - result as dimensions with size one. With this option the result will - broadcast correctly against the original `a`. - - Returns - ------- - n : array - Norm of the matrix or vector(s). - - Notes - ----- - For values of ``ord < 1``, the result is, strictly speaking, not a - mathematical 'norm', but it may still be useful for various numerical - purposes. - - The following norms can be calculated: - - ===== ============================ ========================== - ord norm for matrices norm for vectors - ===== ============================ ========================== - None Frobenius norm 2-norm - 'fro' Frobenius norm -- - inf max(sum(abs(x), axis=1)) max(abs(x)) - -inf min(sum(abs(x), axis=1)) min(abs(x)) - 0 -- sum(x != 0) - 1 max(sum(abs(x), axis=0)) as below - -1 min(sum(abs(x), axis=0)) as below - 2 2-norm (largest sing. value) as below - -2 smallest singular value as below - other -- sum(abs(x)**ord)**(1./ord) - ===== ============================ ========================== - - Nuclear norm and norms based on singular values are not yet implemented. - - The Frobenius norm is given by [1]_: - - :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - - The nuclear norm is the sum of the singular values. - - Both the Frobenius and nuclear norm orders are only defined for - matrices and raise a ValueError when ``a.ndim != 2``. - - References - ---------- - .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, - Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 - - Examples - -------- - >>> import mlx.core as mx - >>> from mlx.core import linalg as LA - >>> a = mx.arange(9) - 4 - >>> a - array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) - >>> b = a.reshape((3,3)) - >>> b - array([[-4, -3, -2], - [-1, 0, 1], - [ 2, 3, 4]], dtype=int32) - >>> LA.norm(a) - array(7.74597, dtype=float32) - >>> LA.norm(b) - array(7.74597, dtype=float32) - >>> LA.norm(b, 'fro') - array(7.74597, dtype=float32) - >>> LA.norm(a, float("inf")) - array(4, dtype=int32) - >>> LA.norm(b, float("inf")) - array(9, dtype=int32) - >>> LA.norm(a, -float("inf")) - array(0, dtype=int32) - >>> LA.norm(b, -float("inf")) - array(2, dtype=int32) - >>> LA.norm(a, 1) - array(20, dtype=int32) - >>> LA.norm(b, 1) - array(7, dtype=int32) - >>> LA.norm(a, -1) - array(0, dtype=float32) - >>> LA.norm(b, -1) - array(6, dtype=int32) - >>> LA.norm(a, 2) - array(7.74597, dtype=float32) - >>> LA.norm(a, 3) - array(5.84804, dtype=float32) - >>> LA.norm(a, -3) - array(0, dtype=float32) - >>> c = mx.array([[ 1, 2, 3], - ... 
[-1, 1, 4]]) - >>> LA.norm(c, axis=0) - array([1.41421, 2.23607, 5], dtype=float32) - >>> LA.norm(c, axis=1) - array([3.74166, 4.24264], dtype=float32) - >>> LA.norm(c, ord=1, axis=1) - array([6, 6], dtype=int32) - >>> m = mx.arange(8).reshape(2,2,2) - array([3.74166, 11.225], dtype=float32) - >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) - (array(3.74166, dtype=float32), array(11.225, dtype=float32)) - )pbdoc"); - m.def( - "norm", - [](const array& a, - const double ord, - const bool keepdims, - const StreamOrDevice stream) { - if (std::isinf((float)ord) || std::isinf(ord)) { - if (ord > 0) - return norm(a, "inf", {}, keepdims, stream); - else - return norm(a, "-inf", {}, keepdims, stream); - } - return norm(a, ord, {}, keepdims, stream); - }, - "a"_a, - "ord"_a, - "keepdims"_a = false, - "stream"_a = none, - R"pbdoc( - Matrix or vector norm. - - This function is able to return matrix or vector norms, - depending on the value of the ``ord`` parameter. - - Parameters - ---------- - a : array_like - Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` - is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. - axis : {None, int, 2-tuple of ints}, optional. - If `axis` is an integer, it specifies the axis of `a` along which to - compute the vector norms. If `axis` is a 2-tuple, it specifies the - axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None then either a vector norm (when `a` - is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default - is None. - keepdims : bool, optional - If this is set to True, the axes which are normed over are left in the - result as dimensions with size one. With this option the result will - broadcast correctly against the original `a`. - - Returns - ------- - n : array - Norm of the matrix or vector(s). - - Notes - ----- - For values of ``ord < 1``, the result is, strictly speaking, not a - mathematical 'norm', but it may still be useful for various numerical - purposes. - - The following norms can be calculated: - - ===== ============================ ========================== - ord norm for matrices norm for vectors - ===== ============================ ========================== - None Frobenius norm 2-norm - 'fro' Frobenius norm -- - inf max(sum(abs(x), axis=1)) max(abs(x)) - -inf min(sum(abs(x), axis=1)) min(abs(x)) - 0 -- sum(x != 0) - 1 max(sum(abs(x), axis=0)) as below - -1 min(sum(abs(x), axis=0)) as below - 2 2-norm (largest sing. value) as below - -2 smallest singular value as below - other -- sum(abs(x)**ord)**(1./ord) - ===== ============================ ========================== - - Nuclear norm and norms based on singular values are not yet implemented. - - The Frobenius norm is given by [1]_: - - :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - - The nuclear norm is the sum of the singular values. - - Both the Frobenius and nuclear norm orders are only defined for - matrices and raise a ValueError when ``a.ndim != 2``. - - References - ---------- - .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, - Baltimore, MD, Johns Hopkins University Press, 1985, pg. 
15 - - Examples - -------- - >>> import mlx.core as mx - >>> from mlx.core import linalg as LA - >>> a = mx.arange(9) - 4 - >>> a - array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) - >>> b = a.reshape((3,3)) - >>> b - array([[-4, -3, -2], - [-1, 0, 1], - [ 2, 3, 4]], dtype=int32) - >>> LA.norm(a) - array(7.74597, dtype=float32) - >>> LA.norm(b) - array(7.74597, dtype=float32) - >>> LA.norm(b, 'fro') - array(7.74597, dtype=float32) - >>> LA.norm(a, float("inf")) - array(4, dtype=int32) - >>> LA.norm(b, float("inf")) - array(9, dtype=int32) - >>> LA.norm(a, -float("inf")) - array(0, dtype=int32) - >>> LA.norm(b, -float("inf")) - array(2, dtype=int32) - >>> LA.norm(a, 1) - array(20, dtype=int32) - >>> LA.norm(b, 1) - array(7, dtype=int32) - >>> LA.norm(a, -1) - array(0, dtype=float32) - >>> LA.norm(b, -1) - array(6, dtype=int32) - >>> LA.norm(a, 2) - array(7.74597, dtype=float32) - >>> LA.norm(a, 3) - array(5.84804, dtype=float32) - >>> LA.norm(a, -3) - array(0, dtype=float32) - >>> c = mx.array([[ 1, 2, 3], - ... [-1, 1, 4]]) - >>> LA.norm(c, axis=0) - array([1.41421, 2.23607, 5], dtype=float32) - >>> LA.norm(c, axis=1) - array([3.74166, 4.24264], dtype=float32) - >>> LA.norm(c, ord=1, axis=1) - array([6, 6], dtype=int32) - >>> m = mx.arange(8).reshape(2,2,2) - array([3.74166, 11.225], dtype=float32) - >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) - (array(3.74166, dtype=float32), array(11.225, dtype=float32)) - )pbdoc"); - m.def( - "norm", - [](const array& a, - const double ord, - const int axis, - const bool keepdims, - const StreamOrDevice stream) { - if (std::isinf((float)ord) || std::isinf(ord)) { - if (ord > 0) - return norm(a, "inf", {axis}, keepdims, stream); - else - return norm(a, "-inf", {axis}, keepdims, stream); - } - return norm(a, ord, {axis}, keepdims, stream); - }, - "a"_a, - "ord"_a, - "axis"_a, - "keepdims"_a = false, - "stream"_a = none, - R"pbdoc( - Matrix or vector norm. - - This function is able to return matrix or vector norms, - depending on the value of the ``ord`` parameter. - - Parameters - ---------- - a : array_like - Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` - is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. - axis : {None, int, 2-tuple of ints}, optional. - If `axis` is an integer, it specifies the axis of `a` along which to - compute the vector norms. If `axis` is a 2-tuple, it specifies the - axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None then either a vector norm (when `a` - is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default - is None. - keepdims : bool, optional - If this is set to True, the axes which are normed over are left in the - result as dimensions with size one. With this option the result will - broadcast correctly against the original `a`. - - Returns - ------- - n : array - Norm of the matrix or vector(s). - - Notes - ----- - For values of ``ord < 1``, the result is, strictly speaking, not a - mathematical 'norm', but it may still be useful for various numerical - purposes. 
- - The following norms can be calculated: - - ===== ============================ ========================== - ord norm for matrices norm for vectors - ===== ============================ ========================== - None Frobenius norm 2-norm - 'fro' Frobenius norm -- - inf max(sum(abs(x), axis=1)) max(abs(x)) - -inf min(sum(abs(x), axis=1)) min(abs(x)) - 0 -- sum(x != 0) - 1 max(sum(abs(x), axis=0)) as below - -1 min(sum(abs(x), axis=0)) as below - 2 2-norm (largest sing. value) as below - -2 smallest singular value as below - other -- sum(abs(x)**ord)**(1./ord) - ===== ============================ ========================== - - Nuclear norm and norms based on singular values are not yet implemented. - - The Frobenius norm is given by [1]_: - - :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - - The nuclear norm is the sum of the singular values. - - Both the Frobenius and nuclear norm orders are only defined for - matrices and raise a ValueError when ``a.ndim != 2``. - - References - ---------- - .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, - Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 - - Examples - -------- - >>> import mlx.core as mx - >>> from mlx.core import linalg as LA - >>> a = mx.arange(9) - 4 - >>> a - array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) - >>> b = a.reshape((3,3)) - >>> b - array([[-4, -3, -2], - [-1, 0, 1], - [ 2, 3, 4]], dtype=int32) - >>> LA.norm(a) - array(7.74597, dtype=float32) - >>> LA.norm(b) - array(7.74597, dtype=float32) - >>> LA.norm(b, 'fro') - array(7.74597, dtype=float32) - >>> LA.norm(a, float("inf")) - array(4, dtype=int32) - >>> LA.norm(b, float("inf")) - array(9, dtype=int32) - >>> LA.norm(a, -float("inf")) - array(0, dtype=int32) - >>> LA.norm(b, -float("inf")) - array(2, dtype=int32) - >>> LA.norm(a, 1) - array(20, dtype=int32) - >>> LA.norm(b, 1) - array(7, dtype=int32) - >>> LA.norm(a, -1) - array(0, dtype=float32) - >>> LA.norm(b, -1) - array(6, dtype=int32) - >>> LA.norm(a, 2) - array(7.74597, dtype=float32) - >>> LA.norm(a, 3) - array(5.84804, dtype=float32) - >>> LA.norm(a, -3) - array(0, dtype=float32) - >>> c = mx.array([[ 1, 2, 3], - ... [-1, 1, 4]]) - >>> LA.norm(c, axis=0) - array([1.41421, 2.23607, 5], dtype=float32) - >>> LA.norm(c, axis=1) - array([3.74166, 4.24264], dtype=float32) - >>> LA.norm(c, ord=1, axis=1) - array([6, 6], dtype=int32) - >>> m = mx.arange(8).reshape(2,2,2) - array([3.74166, 11.225], dtype=float32) - >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) - (array(3.74166, dtype=float32), array(11.225, dtype=float32)) - )pbdoc"); - m.def( - "norm", - [](const array& a, - const double ord, - const std::vector& axis, - const bool keepdims, - const StreamOrDevice stream) { - if (std::isinf((float)ord) || std::isinf(ord)) { - if (ord > 0) - return norm(a, "inf", axis, keepdims, stream); - else - return norm(a, "-inf", axis, keepdims, stream); - } - return norm(a, ord, axis, keepdims, stream); - }, - "a"_a, - "ord"_a, - "axis"_a, - "keepdims"_a = false, - "stream"_a = none, - R"pbdoc( - Matrix or vector norm. - - This function is able to return matrix or vector norms, - depending on the value of the ``ord`` parameter. - - Parameters - ---------- - a : array_like - Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` - is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. 
- axis : {None, int, 2-tuple of ints}, optional. - If `axis` is an integer, it specifies the axis of `a` along which to - compute the vector norms. If `axis` is a 2-tuple, it specifies the - axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None then either a vector norm (when `a` - is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default - is None. - keepdims : bool, optional - If this is set to True, the axes which are normed over are left in the - result as dimensions with size one. With this option the result will - broadcast correctly against the original `a`. - - Returns - ------- - n : array - Norm of the matrix or vector(s). - - Notes - ----- - For values of ``ord < 1``, the result is, strictly speaking, not a - mathematical 'norm', but it may still be useful for various numerical - purposes. - - The following norms can be calculated: - - ===== ============================ ========================== - ord norm for matrices norm for vectors - ===== ============================ ========================== - None Frobenius norm 2-norm - 'fro' Frobenius norm -- - inf max(sum(abs(x), axis=1)) max(abs(x)) - -inf min(sum(abs(x), axis=1)) min(abs(x)) - 0 -- sum(x != 0) - 1 max(sum(abs(x), axis=0)) as below - -1 min(sum(abs(x), axis=0)) as below - 2 2-norm (largest sing. value) as below - -2 smallest singular value as below - other -- sum(abs(x)**ord)**(1./ord) - ===== ============================ ========================== - - Nuclear norm and norms based on singular values are not yet implemented. - - The Frobenius norm is given by [1]_: - - :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - - The nuclear norm is the sum of the singular values. - - Both the Frobenius and nuclear norm orders are only defined for - matrices and raise a ValueError when ``a.ndim != 2``. - - References - ---------- - .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, - Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 - - Examples - -------- - >>> import mlx.core as mx - >>> from mlx.core import linalg as LA - >>> a = mx.arange(9) - 4 - >>> a - array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) - >>> b = a.reshape((3,3)) - >>> b - array([[-4, -3, -2], - [-1, 0, 1], - [ 2, 3, 4]], dtype=int32) - >>> LA.norm(a) - array(7.74597, dtype=float32) - >>> LA.norm(b) - array(7.74597, dtype=float32) - >>> LA.norm(b, 'fro') - array(7.74597, dtype=float32) - >>> LA.norm(a, float("inf")) - array(4, dtype=int32) - >>> LA.norm(b, float("inf")) - array(9, dtype=int32) - >>> LA.norm(a, -float("inf")) - array(0, dtype=int32) - >>> LA.norm(b, -float("inf")) - array(2, dtype=int32) - >>> LA.norm(a, 1) - array(20, dtype=int32) - >>> LA.norm(b, 1) - array(7, dtype=int32) - >>> LA.norm(a, -1) - array(0, dtype=float32) - >>> LA.norm(b, -1) - array(6, dtype=int32) - >>> LA.norm(a, 2) - array(7.74597, dtype=float32) - >>> LA.norm(a, 3) - array(5.84804, dtype=float32) - >>> LA.norm(a, -3) - array(0, dtype=float32) - >>> c = mx.array([[ 1, 2, 3], - ... 
[-1, 1, 4]]) - >>> LA.norm(c, axis=0) - array([1.41421, 2.23607, 5], dtype=float32) - >>> LA.norm(c, axis=1) - array([3.74166, 4.24264], dtype=float32) - >>> LA.norm(c, ord=1, axis=1) - array([6, 6], dtype=int32) - >>> m = mx.arange(8).reshape(2,2,2) - array([3.74166, 11.225], dtype=float32) - >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) - (array(3.74166, dtype=float32), array(11.225, dtype=float32)) - )pbdoc"); - m.def( - "norm", - [](const array& a, - const std::string& ord, - const bool keepdims, - const StreamOrDevice stream) { - return norm(a, ord, {}, keepdims, stream); - }, - "a"_a, - "ord"_a, - "keepdims"_a = false, - "stream"_a = none, - R"pbdoc( - Matrix or vector norm. - - This function is able to return matrix or vector norms, - depending on the value of the ``ord`` parameter. - - Parameters - ---------- - a : array_like - Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` - is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. - axis : {None, int, 2-tuple of ints}, optional. - If `axis` is an integer, it specifies the axis of `a` along which to - compute the vector norms. If `axis` is a 2-tuple, it specifies the - axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None then either a vector norm (when `a` - is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default - is None. - keepdims : bool, optional - If this is set to True, the axes which are normed over are left in the - result as dimensions with size one. With this option the result will - broadcast correctly against the original `a`. - - Returns - ------- - n : array - Norm of the matrix or vector(s). - - Notes - ----- - For values of ``ord < 1``, the result is, strictly speaking, not a - mathematical 'norm', but it may still be useful for various numerical - purposes. - - The following norms can be calculated: - - ===== ============================ ========================== - ord norm for matrices norm for vectors - ===== ============================ ========================== - None Frobenius norm 2-norm - 'fro' Frobenius norm -- - inf max(sum(abs(x), axis=1)) max(abs(x)) - -inf min(sum(abs(x), axis=1)) min(abs(x)) - 0 -- sum(x != 0) - 1 max(sum(abs(x), axis=0)) as below - -1 min(sum(abs(x), axis=0)) as below - 2 2-norm (largest sing. value) as below - -2 smallest singular value as below - other -- sum(abs(x)**ord)**(1./ord) - ===== ============================ ========================== - - Nuclear norm and norms based on singular values are not yet implemented. - - The Frobenius norm is given by [1]_: - - :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - - The nuclear norm is the sum of the singular values. - - Both the Frobenius and nuclear norm orders are only defined for - matrices and raise a ValueError when ``a.ndim != 2``. - - References - ---------- - .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, - Baltimore, MD, Johns Hopkins University Press, 1985, pg. 
15 - - Examples - -------- - >>> import mlx.core as mx - >>> from mlx.core import linalg as LA - >>> a = mx.arange(9) - 4 - >>> a - array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) - >>> b = a.reshape((3,3)) - >>> b - array([[-4, -3, -2], - [-1, 0, 1], - [ 2, 3, 4]], dtype=int32) - >>> LA.norm(a) - array(7.74597, dtype=float32) - >>> LA.norm(b) - array(7.74597, dtype=float32) - >>> LA.norm(b, 'fro') - array(7.74597, dtype=float32) - >>> LA.norm(a, float("inf")) - array(4, dtype=int32) - >>> LA.norm(b, float("inf")) - array(9, dtype=int32) - >>> LA.norm(a, -float("inf")) - array(0, dtype=int32) - >>> LA.norm(b, -float("inf")) - array(2, dtype=int32) - >>> LA.norm(a, 1) - array(20, dtype=int32) - >>> LA.norm(b, 1) - array(7, dtype=int32) - >>> LA.norm(a, -1) - array(0, dtype=float32) - >>> LA.norm(b, -1) - array(6, dtype=int32) - >>> LA.norm(a, 2) - array(7.74597, dtype=float32) - >>> LA.norm(a, 3) - array(5.84804, dtype=float32) - >>> LA.norm(a, -3) - array(0, dtype=float32) - >>> c = mx.array([[ 1, 2, 3], - ... [-1, 1, 4]]) - >>> LA.norm(c, axis=0) - array([1.41421, 2.23607, 5], dtype=float32) - >>> LA.norm(c, axis=1) - array([3.74166, 4.24264], dtype=float32) - >>> LA.norm(c, ord=1, axis=1) - array([6, 6], dtype=int32) - >>> m = mx.arange(8).reshape(2,2,2) - array([3.74166, 11.225], dtype=float32) - >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) - (array(3.74166, dtype=float32), array(11.225, dtype=float32)) - )pbdoc"); - m.def( - "norm", - [](const array& a, - const std::string& ord, - const int axis, - const bool keepdims, - const StreamOrDevice stream) { - return norm(a, ord, {axis}, keepdims, stream); - }, - "a"_a, - "ord"_a, - "axis"_a, - "keepdims"_a = false, - "stream"_a = none, - R"pbdoc( - Matrix or vector norm. - - This function is able to return matrix or vector norms, - depending on the value of the ``ord`` parameter. - - Parameters - ---------- - a : array_like - Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` - is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. - axis : {None, int, 2-tuple of ints}, optional. - If `axis` is an integer, it specifies the axis of `a` along which to - compute the vector norms. If `axis` is a 2-tuple, it specifies the - axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None then either a vector norm (when `a` - is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default - is None. - keepdims : bool, optional - If this is set to True, the axes which are normed over are left in the - result as dimensions with size one. With this option the result will - broadcast correctly against the original `a`. - - Returns - ------- - n : array - Norm of the matrix or vector(s). - - Notes - ----- - For values of ``ord < 1``, the result is, strictly speaking, not a - mathematical 'norm', but it may still be useful for various numerical - purposes. 
- - The following norms can be calculated: - - ===== ============================ ========================== - ord norm for matrices norm for vectors - ===== ============================ ========================== - None Frobenius norm 2-norm - 'fro' Frobenius norm -- - inf max(sum(abs(x), axis=1)) max(abs(x)) - -inf min(sum(abs(x), axis=1)) min(abs(x)) - 0 -- sum(x != 0) - 1 max(sum(abs(x), axis=0)) as below - -1 min(sum(abs(x), axis=0)) as below - 2 2-norm (largest sing. value) as below - -2 smallest singular value as below - other -- sum(abs(x)**ord)**(1./ord) - ===== ============================ ========================== - - Nuclear norm and norms based on singular values are not yet implemented. - - The Frobenius norm is given by [1]_: - - :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}` - - The nuclear norm is the sum of the singular values. - - Both the Frobenius and nuclear norm orders are only defined for - matrices and raise a ValueError when ``a.ndim != 2``. - - References - ---------- - .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*, - Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15 - - Examples - -------- - >>> import mlx.core as mx - >>> from mlx.core import linalg as LA - >>> a = mx.arange(9) - 4 - >>> a - array([-4, -3, -2, ..., 2, 3, 4], dtype=int32) - >>> b = a.reshape((3,3)) - >>> b - array([[-4, -3, -2], - [-1, 0, 1], - [ 2, 3, 4]], dtype=int32) - >>> LA.norm(a) - array(7.74597, dtype=float32) - >>> LA.norm(b) - array(7.74597, dtype=float32) - >>> LA.norm(b, 'fro') - array(7.74597, dtype=float32) - >>> LA.norm(a, float("inf")) - array(4, dtype=int32) - >>> LA.norm(b, float("inf")) - array(9, dtype=int32) - >>> LA.norm(a, -float("inf")) - array(0, dtype=int32) - >>> LA.norm(b, -float("inf")) - array(2, dtype=int32) - >>> LA.norm(a, 1) - array(20, dtype=int32) - >>> LA.norm(b, 1) - array(7, dtype=int32) - >>> LA.norm(a, -1) - array(0, dtype=float32) - >>> LA.norm(b, -1) - array(6, dtype=int32) - >>> LA.norm(a, 2) - array(7.74597, dtype=float32) - >>> LA.norm(a, 3) - array(5.84804, dtype=float32) - >>> LA.norm(a, -3) - array(0, dtype=float32) - >>> c = mx.array([[ 1, 2, 3], - ... [-1, 1, 4]]) - >>> LA.norm(c, axis=0) - array([1.41421, 2.23607, 5], dtype=float32) - >>> LA.norm(c, axis=1) - array([3.74166, 4.24264], dtype=float32) - >>> LA.norm(c, ord=1, axis=1) - array([6, 6], dtype=int32) - >>> m = mx.arange(8).reshape(2,2,2) - array([3.74166, 11.225], dtype=float32) - >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) - (array(3.74166, dtype=float32), array(11.225, dtype=float32)) - )pbdoc"); - m.def( - "norm", - [](const array& a, - const std::string& ord, - const std::vector& axis, - const bool keepdims, - const StreamOrDevice stream) { - return norm(a, ord, axis, keepdims, stream); - }, - "a"_a, - "ord"_a, - "axis"_a, + "ord"_a = none, + "axis"_a = none, "keepdims"_a = false, "stream"_a = none, R"pbdoc( @@ -1386,6 +191,7 @@ void init_linalg(py::module_& parent_module) { >>> LA.norm(c, ord=1, axis=1) array([6, 6], dtype=int32) >>> m = mx.arange(8).reshape(2,2,2) + >>> LA.norm(m, axis=(1,2)) array([3.74166, 11.225], dtype=float32) >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :]) (array(3.74166, dtype=float32), array(11.225, dtype=float32)) diff --git a/python/src/overloaded.h b/python/src/overloaded.h new file mode 100644 index 000000000..204a3178a --- /dev/null +++ b/python/src/overloaded.h @@ -0,0 +1,8 @@ +// Copyright © 2023 Apple Inc. + +template +struct overloaded : Ts... { + using Ts::operator()...; +}; +template +overloaded(Ts...) 
-> overloaded<Ts...>;
\ No newline at end of file

From f82ab0eec9cf78921c41934b7d2d8cb644049c07 Mon Sep 17 00:00:00 2001
From: Gabrijel Boduljak
Date: Fri, 22 Dec 2023 04:58:51 +0100
Subject: [PATCH 29/37] reused existing util for implementation of linalg.norm

---
 mlx/linalg.cpp     |  7 +++----
 mlx/utils.h        | 19 +++++++++++++++----
 python/src/fft.cpp |  1 +
 python/src/utils.h | 15 ---------------
 4 files changed, 19 insertions(+), 23 deletions(-)

diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index 614e6f79c..f541c6214 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -96,8 +96,7 @@ array norm(
   if (num_axes == 0 || num_axes == 1 || num_axes == 2)
     return sqrt(
         sum(abs(a, s) * abs(a, s),
-            num_axes ? axis
-                     : get_shape_reducing_over_all_axes(a.shape().size()),
+            num_axes ? axis : get_reduce_axes({}, a.ndim()),
             keepdims,
             s),
         s);
@@ -116,7 +115,7 @@ array norm(
   std::vector<int> ax = axis;

   if (axis.empty())
-    ax = get_shape_reducing_over_all_axes(a.ndim());
+    ax = get_reduce_axes({}, a.ndim());
   else
     ax = normalize_axes(ax, a.ndim());

@@ -140,7 +139,7 @@ array norm(
   std::vector<int> ax = axis;

   if (axis.empty())
-    ax = get_shape_reducing_over_all_axes(a.ndim());
+    ax = get_reduce_axes({}, a.ndim());
   else
     ax = normalize_axes(ax, a.ndim());

diff --git a/mlx/utils.h b/mlx/utils.h
index 1158b7c42..0b0ae9e93 100644
--- a/mlx/utils.h
+++ b/mlx/utils.h
@@ -2,6 +2,7 @@

 #pragma once

+#include <variant>
 #include "array.h"
 #include "device.h"
 #include "dtype.h"
@@ -42,8 +43,18 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) {
   return os << static_cast<float>(v);
 }
 inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) {
   return os << static_cast<float>(v);
 }
-/**
- * Returns the axes vector [0, 1, ... ndim).
- */
-std::vector<int> get_shape_reducing_over_all_axes(int ndim);
+
+using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
+inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
+  std::vector<int> axes;
+  if (std::holds_alternative<std::monostate>(v)) {
+    axes.resize(dims);
+    std::iota(axes.begin(), axes.end(), 0);
+  } else if (auto pv = std::get_if<int>(&v); pv) {
+    axes.push_back(*pv);
+  } else {
+    axes = std::get<std::vector<int>>(v);
+  }
+  return axes;
+}
 } // namespace mlx::core
diff --git a/python/src/fft.cpp b/python/src/fft.cpp
index 42ad37633..6b3739ae6 100644
--- a/python/src/fft.cpp
+++ b/python/src/fft.cpp
@@ -7,6 +7,7 @@

 #include "mlx/fft.h"
 #include "mlx/ops.h"
+#include "mlx/utils.h"

 namespace py = pybind11;
 using namespace py::literals;
diff --git a/python/src/utils.h b/python/src/utils.h
index 5ac878979..9751b2d6e 100644
--- a/python/src/utils.h
+++ b/python/src/utils.h
@@ -1,7 +1,6 @@
 // Copyright © 2023 Apple Inc.
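Aside for readers of this series: the `get_reduce_axes` helper moved into `mlx/utils.h` above normalizes the three accepted axis forms into a plain axis list. A minimal Python mirror of the same rule, with `None` standing in for `std::monostate` (the sketch is illustrative, not library code):

    def get_reduce_axes(v, ndim):
        # No axes given: reduce over every axis, i.e. [0, 1, ..., ndim - 1].
        if v is None:
            return list(range(ndim))
        # A single integer axis becomes a one-element list.
        if isinstance(v, int):
            return [v]
        # Otherwise assume an iterable of integer axes.
        return list(v)

    assert get_reduce_axes(None, 3) == [0, 1, 2]
    assert get_reduce_axes(1, 3) == [1]
    assert get_reduce_axes([0, 2], 3) == [0, 2]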
#pragma once -#include #include #include @@ -14,24 +13,10 @@ namespace py = pybind11; using namespace mlx::core; -using IntOrVec = std::variant>; using ScalarOrArray = std:: variant, py::object>; static constexpr std::monostate none{}; -inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { - std::vector axes; - if (std::holds_alternative(v)) { - axes.resize(dims); - std::iota(axes.begin(), axes.end(), 0); - } else if (auto pv = std::get_if(&v); pv) { - axes.push_back(*pv); - } else { - axes = std::get>(v); - } - return axes; -} - inline array to_array_with_accessor(py::object obj) { if (py::hasattr(obj, "__mlx_array__")) { return obj.attr("__mlx_array__")().cast(); From 5a184d5b5dc2de98fe489e92a9b5b601be2e0bc0 Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Fri, 22 Dec 2023 05:34:06 +0100 Subject: [PATCH 30/37] more tests --- python/tests/test_linalg.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 1969e1028..6c6d34699 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -42,11 +42,27 @@ class TestLinalg(mlx_tests.MLXTestCase): ) assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) - # Test when no axes and no ords are provided - for keepdims in [True, False]: - out_np = np.linalg.norm(x_np, keepdims=keepdims) - out_mx = mx.linalg.norm(x_mx, keepdims=keepdims) - assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + # Test only axis provided + for shape in [(3,), (2, 3), (2, 3, 3)]: + x_mx = mx.arange(math.prod(shape)).reshape(shape) + x_np = np.arange(math.prod(shape)).reshape(shape) + + for num_axes in range(1, len(shape)): + for axis in itertools.combinations(range(len(shape)), num_axes): + for keepdims in [True, False]: + out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims) + assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + + # Test only ord provided + for shape in [(3,), (2, 3)]: + x_mx = mx.arange(math.prod(shape)).reshape(shape) + x_np = np.arange(math.prod(shape)).reshape(shape) + for o in [None, 1, -1, float("inf"), -float("inf")]: + for keepdims in [True, False]: + out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims) + assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) if __name__ == "__main__": From bbfe042a2bf2c8bbd1d3dedf6b7fc0fb6633b2db Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Fri, 22 Dec 2023 12:00:18 +0100 Subject: [PATCH 31/37] fixed a bug with no ord and axis provided --- python/src/linalg.cpp | 27 +++++++++++++++++++++++---- python/tests/test_linalg.py | 16 ++++++++++++++++ 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index d96dd8a2d..de5ccfcf3 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -44,7 +44,9 @@ void init_linalg(py::module_& parent_module) { return norm( a, "inf", - get_reduce_axes(axis, a.ndim()), + std::holds_alternative(axis) + ? std::vector() + : get_reduce_axes(axis, a.ndim()), keepdims, stream); } @@ -56,15 +58,32 @@ void init_linalg(py::module_& parent_module) { stream); } return norm( - a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream); + a, + p, + std::holds_alternative(axis) + ? 
std::vector() + : get_reduce_axes(axis, a.ndim()), + keepdims, + stream); }, [&](const std::string& p) { return norm( - a, p, get_reduce_axes(axis, a.ndim()), keepdims, stream); + a, + p, + std::holds_alternative(axis) + ? std::vector() + : get_reduce_axes(axis, a.ndim()), + keepdims, + stream); }, [&](const std::monostate _) { return norm( - a, get_reduce_axes(axis, a.ndim()), keepdims, stream); + a, + std::holds_alternative(axis) + ? std::vector() + : get_reduce_axes(axis, a.ndim()), + keepdims, + stream); }}, ord); }, diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 6c6d34699..ce1926de0 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -24,6 +24,12 @@ class TestLinalg(mlx_tests.MLXTestCase): ords = vector_ords else: ords = matrix_ords + for keepdims in [True, False]: + # Test axis provided, no ord provided + out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims) + assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + # Test both ord and axis provided for o in ords: for keepdims in [True, False]: if o: @@ -64,6 +70,16 @@ class TestLinalg(mlx_tests.MLXTestCase): out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims) assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + # Test no ord and no axis provided + for shape in [(3,), (2, 3), (2, 3, 3)]: + x_mx = mx.arange(math.prod(shape)).reshape(shape) + x_np = np.arange(math.prod(shape)).reshape(shape) + for o in [None, 1, -1, float("inf"), -float("inf")]: + for keepdims in [True, False]: + out_np = np.linalg.norm(x_np, keepdims=keepdims) + out_mx = mx.linalg.norm(x_mx, keepdims=keepdims) + assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6) + if __name__ == "__main__": unittest.main() From 4bae4a8239cc64a8be7e038d3e096253c03708ab Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Sun, 24 Dec 2023 04:40:44 +0100 Subject: [PATCH 32/37] removed unused imports --- mlx/linalg.cpp | 16 +++++++--------- mlx/linalg.h | 3 --- mlx/utils.cpp | 6 ------ 3 files changed, 7 insertions(+), 18 deletions(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index f541c6214..33fa92083 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -1,7 +1,5 @@ // Copyright © 2023 Apple Inc. 
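The bug addressed in PATCH 31 above is easiest to state from the Python side: when neither `ord` nor `axis` is given, `norm` must reduce over all axes of the flattened input rather than feeding an empty axis list through `get_reduce_axes`. A small check in the spirit of the tests added above (assuming `mlx` and `numpy` are installed; the shape is arbitrary):

    import math

    import mlx.core as mx
    import numpy as np

    shape = (2, 3, 3)
    x_mx = mx.arange(math.prod(shape)).reshape(shape)
    x_np = np.arange(math.prod(shape)).reshape(shape)

    # With no ord and no axis, both libraries compute the 2-norm of the
    # flattened array, regardless of the input's dimensionality.
    assert np.allclose(np.linalg.norm(x_np), mx.linalg.norm(x_mx), atol=1e-5)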
-#include -#include #include #include #include @@ -9,7 +7,7 @@ #include "mlx/array.h" #include "mlx/linalg.h" #include "mlx/ops.h" -#include "utils.h" +#include "mlx/utils.h" namespace mlx::core::linalg { @@ -41,7 +39,7 @@ inline array vector_norm( return max(abs(a, s), axis, keepdims, s); else if (ord == "-inf") return min(abs(a, s), axis, keepdims, s); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid ord value " << ord; throw std::invalid_argument(error_stream.str()); } @@ -62,7 +60,7 @@ inline array matrix_norm( return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); if (ord == 2.0 || ord == -2.0) throw std::logic_error("Singular value norms are not implemented."); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid ord value " << ord << " for matrix norm"; throw std::invalid_argument(error_stream.str()); } @@ -81,7 +79,7 @@ inline array matrix_norm( return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s); if (ord == "nuc") throw std::logic_error("Nuclear norm is not implemented."); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid ord value " << ord << " for matrix norm"; throw std::invalid_argument(error_stream.str()); } @@ -101,7 +99,7 @@ array norm( s), s); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid axis values " << axis; throw std::invalid_argument(error_stream.str()); } @@ -125,7 +123,7 @@ array norm( else if (num_axes == 2) return matrix_norm(a, ord, ax, keepdims, s); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid axis values " << ax; throw std::invalid_argument(error_stream.str()); } @@ -149,7 +147,7 @@ array norm( else if (num_axes == 2) return matrix_norm(a, ord, ax, keepdims, s); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid axis values " << ax; throw std::invalid_argument(error_stream.str()); } diff --git a/mlx/linalg.h b/mlx/linalg.h index 690df343c..d77ada477 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -2,13 +2,10 @@ #pragma once -#include - #include "array.h" #include "device.h" #include "ops.h" #include "stream.h" -#include "string.h" namespace mlx::core::linalg { array norm( diff --git a/mlx/utils.cpp b/mlx/utils.cpp index ddcb41ba8..932217ad4 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -286,10 +286,4 @@ std::ostream& operator<<(std::ostream& os, const std::vector& v) { return os; } -std::vector get_shape_reducing_over_all_axes(int ndim) { - std::vector shape(ndim); - std::iota(shape.begin(), shape.end(), 0); - return shape; -} - } // namespace mlx::core From f7cea9563df25a7b6d63a78f0b6cec99a4fa2e5c Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 26 Dec 2023 10:54:59 -0800 Subject: [PATCH 33/37] some style and API consistency updates to linalg norm --- docs/src/index.rst | 1 + docs/src/python/linalg.rst | 4 +- mlx/linalg.cpp | 196 ++++++++++---------- mlx/linalg.h | 42 ++++- mlx/utils.cpp | 7 - mlx/utils.h | 16 -- python/src/fft.cpp | 1 - python/src/linalg.cpp | 293 ++++++++++++++---------------- python/src/utils.h | 15 ++ tests/linalg_tests.cpp | 360 ++++++++++++++++--------------------- 10 files changed, 437 insertions(+), 498 deletions(-) diff --git a/docs/src/index.rst b/docs/src/index.rst index ac4932f10..207238f37 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -57,6 +57,7 @@ are the CPU and GPU. 
python/random python/transforms python/fft + python/linalg python/nn python/optimizers python/tree_utils diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index 6c9daa100..27746441e 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -1,11 +1,11 @@ .. _linalg: Linear Algebra -===== +============== .. currentmodule:: mlx.core.linalg .. autosummary:: :toctree: _autosummary - norm \ No newline at end of file + norm diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 33fa92083..61c9e8537 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -1,47 +1,42 @@ // Copyright © 2023 Apple Inc. -#include -#include +#include +#include #include #include "mlx/array.h" #include "mlx/linalg.h" #include "mlx/ops.h" -#include "mlx/utils.h" namespace mlx::core::linalg { +Dtype at_least_float(const Dtype& d) { + return is_floating_point(d) ? d : promote_types(d, float32); +} + inline array vector_norm( const array& a, const double ord, const std::vector& axis, bool keepdims, StreamOrDevice s) { - if (ord == 0.0) - return sum(a != 0, axis, keepdims, s); - else if (ord == 1.0) - return sum(abs(a, s), axis, keepdims, s); - else if (ord == 2.0) - return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s)); - else + auto dtype = at_least_float(a.dtype()); + if (ord == 0.0) { + return astype(sum(not_equal(a, array(0), s), axis, keepdims, s), dtype, s); + } else if (ord == 1.0) { + return astype(sum(abs(a, s), axis, keepdims, s), dtype, s); + } else if (ord == 2.0) { + return sqrt(sum(square(a, s), axis, keepdims, s), s); + } else if (ord == std::numeric_limits::infinity()) { + return astype(max(abs(a, s), axis, keepdims, s), dtype, s); + } else if (ord == -std::numeric_limits::infinity()) { + return astype(min(abs(a, s), axis, keepdims, s), dtype, s); + } else { return power( - sum(power(abs(a, s), array(ord), s), axis, keepdims, s), - array(1.0 / ord)); -} - -inline array vector_norm( - const array& a, - const std::string& ord, - const std::vector& axis, - bool keepdims, - StreamOrDevice s) { - if (ord == "inf") - return max(abs(a, s), axis, keepdims, s); - else if (ord == "-inf") - return min(abs(a, s), axis, keepdims, s); - std::ostringstream error_stream; - error_stream << "Invalid ord value " << ord; - throw std::invalid_argument(error_stream.str()); + sum(power(abs(a, s), array(ord, dtype), s), axis, keepdims, s), + array(1.0 / ord, dtype), + s); + } } inline array matrix_norm( @@ -50,19 +45,30 @@ inline array matrix_norm( const std::vector& axis, bool keepdims, StreamOrDevice s) { + auto dtype = at_least_float(a.dtype()); auto row_axis = axis[0]; auto col_axis = axis[1]; - if (!keepdims && col_axis > row_axis) + if (!keepdims && col_axis > row_axis && col_axis > 0) { col_axis -= 1; - if (ord == -1.0) - return min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); - if (ord == 1.0) - return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); - if (ord == 2.0 || ord == -2.0) - throw std::logic_error("Singular value norms are not implemented."); - std::ostringstream error_stream; - error_stream << "Invalid ord value " << ord << " for matrix norm"; - throw std::invalid_argument(error_stream.str()); + } + if (ord == -1.0) { + return astype( + min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), + dtype, + s); + } else if (ord == 1.0) { + return astype( + max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s), + dtype, + s); + } else if (ord == 2.0 || ord == -2.0) { + throw std::runtime_error( + "[linalg::norm] Singular value 
norms are not implemented."); + } else { + std::ostringstream msg; + msg << "[linalg::norm] Invalid ord value " << ord << " for matrix norm"; + throw std::invalid_argument(msg.str()); + } } inline array matrix_norm( @@ -71,85 +77,77 @@ inline array matrix_norm( const std::vector& axis, bool keepdims, StreamOrDevice s) { - if (ord == "f" || ord == "fro") + if (ord == "f" || ord == "fro") { return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s)); - else if (ord == "inf") - return matrix_norm(a, 1.0, {axis[1], axis[0]}, keepdims, s); - else if (ord == "-inf") - return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s); - if (ord == "nuc") - throw std::logic_error("Nuclear norm is not implemented."); - std::ostringstream error_stream; - error_stream << "Invalid ord value " << ord << " for matrix norm"; - throw std::invalid_argument(error_stream.str()); + } else if (ord == "nuc") { + throw std::runtime_error( + "[linalg::norm] Nuclear norm not yet implemented."); + } else { + std::ostringstream msg; + msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm"; + throw std::invalid_argument(msg.str()); + } } array norm( const array& a, - const std::vector& axis, - bool keepdims, - StreamOrDevice s) { - auto num_axes = axis.size(); + const std::optional>& axis /* = std::nullopt */, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + if (!axis) { + return norm(flatten(a, s), std::vector{0}, keepdims, s); + } - if (num_axes == 0 || num_axes == 1 || num_axes == 2) - return sqrt( - sum(abs(a, s) * abs(a, s), - num_axes ? axis : get_reduce_axes({}, a.ndim()), - keepdims, - s), - s); - - std::ostringstream error_stream; - error_stream << "Invalid axis values " << axis; - throw std::invalid_argument(error_stream.str()); + if (axis.value().size() > 2) { + throw std::invalid_argument( + "[linalg::norm] Received too many axes for norm"); + } + return sqrt(sum(square(a, s), axis.value(), keepdims, s), s); } array norm( const array& a, const double ord, - const std::vector& axis, - bool keepdims, - StreamOrDevice s) { - std::vector ax = axis; - - if (axis.empty()) - ax = get_reduce_axes({}, a.ndim()); - else - ax = normalize_axes(ax, a.ndim()); - - auto num_axes = ax.size(); - if (num_axes == 1) + const std::optional>& axis /* = std::nullopt */, + bool keepdims /* = false */, + StreamOrDevice s /* = {} */) { + std::vector ax; + if (!axis) { + ax.resize(a.ndim()); + std::iota(ax.begin(), ax.end(), 0); + } else { + ax = axis.value(); + } + if (ax.size() == 1) { return vector_norm(a, ord, ax, keepdims, s); - else if (num_axes == 2) + } else if (ax.size() == 2) { return matrix_norm(a, ord, ax, keepdims, s); - - std::ostringstream error_stream; - error_stream << "Invalid axis values " << ax; - throw std::invalid_argument(error_stream.str()); + } else { + throw std::invalid_argument( + "[linalg::norm] Received too many axes for norm"); + } } array norm( const array& a, const std::string& ord, - const std::vector& axis, - bool keepdims, - StreamOrDevice s) { - std::vector ax = axis; - - if (axis.empty()) - ax = get_reduce_axes({}, a.ndim()); - else - ax = normalize_axes(ax, a.ndim()); - - auto num_axes = ax.size(); - if (num_axes == 1) - return vector_norm(a, ord, ax, keepdims, s); - else if (num_axes == 2) - return matrix_norm(a, ord, ax, keepdims, s); - - std::ostringstream error_stream; - error_stream << "Invalid axis values " << ax; - throw std::invalid_argument(error_stream.str()); + const std::optional>& axis /* = std::nullopt */, + bool keepdims /* = false */, + 
StreamOrDevice s /* = {} */) { + std::vector ax; + if (!axis) { + ax.resize(a.ndim()); + std::iota(ax.begin(), ax.end(), 0); + } else { + ax = axis.value(); + } + if (ax.size() != 2) { + std::ostringstream msg; + msg << "[linalg::norm] Norm '" << ord << "' only supported for matrices," + << " but received " << ax.size() << " axis/axes."; + throw std::invalid_argument(msg.str()); + } + return matrix_norm(a, ord, ax, keepdims, s); } -} // namespace mlx::core::linalg \ No newline at end of file +} // namespace mlx::core::linalg diff --git a/mlx/linalg.h b/mlx/linalg.h index d77ada477..bf3b5e78c 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -2,27 +2,61 @@ #pragma once +#include + #include "array.h" #include "device.h" #include "ops.h" #include "stream.h" namespace mlx::core::linalg { + +/* + * Compute vector or matrix norms. + * + * - If axis and ord are both unspecified, computes the 2-norm of flatten(x). + * - If axis is not provided but ord is, then x must be either 1D or 2D. + * - If axis is provided, but ord is not, then the 2-norm is computed along the + * given axes. At most 2 axes can be specified. + * - If both axis and ord are provided, then the corresponding matrix of vector + * norm is computed. At most 2 axes can be specified. + */ array norm( const array& a, const double ord, - const std::vector& axis = {}, + const std::optional>& axis = std::nullopt, bool keepdims = false, StreamOrDevice s = {}); +inline array norm( + const array& a, + const double ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} array norm( const array& a, const std::string& ord, - const std::vector& axis = {}, + const std::optional>& axis = std::nullopt, bool keepdims = false, StreamOrDevice s = {}); +inline array norm( + const array& a, + const std::string& ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} array norm( const array& a, - const std::vector& axis = {}, + const std::optional>& axis = std::nullopt, bool keepdims = false, StreamOrDevice s = {}); -} // namespace mlx::core::linalg \ No newline at end of file +inline array +norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) { + return norm(a, std::vector{axis}, keepdims, s); +} + +} // namespace mlx::core::linalg diff --git a/mlx/utils.cpp b/mlx/utils.cpp index 932217ad4..1fbc67c8e 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -1,6 +1,5 @@ // Copyright © 2023 Apple Inc. 
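The `ord` dispatch in `vector_norm` above reduces to a handful of closed forms. A Python sketch of the same table for a 1-D input (illustrative only; the real implementation also promotes integer inputs to a floating-point dtype via `at_least_float`):

    import mlx.core as mx

    def vector_norm_sketch(a, ord):
        # Mirrors the special cases in the C++ vector_norm.
        if ord == 0:
            return mx.sum(a != 0)               # number of nonzero entries
        if ord == float("inf"):
            return mx.max(mx.abs(a))            # largest magnitude
        if ord == -float("inf"):
            return mx.min(mx.abs(a))            # smallest magnitude
        if ord == 2:
            return mx.sqrt(mx.sum(mx.square(a)))
        # General p-norm: (sum |a_i|^p)^(1/p).
        return mx.power(mx.sum(mx.power(mx.abs(a), ord)), 1 / ord)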
-#include #include #include @@ -74,12 +73,6 @@ int normalize_axis(int axis, int ndim) { } return axis; } -std::vector normalize_axes(const std::vector& axes, int ndim) { - std::vector canonical; - for (int ax : axes) - canonical.push_back(normalize_axis(ax, ndim)); - return canonical; -} std::ostream& operator<<(std::ostream& os, const Device& d) { os << "Device("; diff --git a/mlx/utils.h b/mlx/utils.h index 0b0ae9e93..823b4c872 100644 --- a/mlx/utils.h +++ b/mlx/utils.h @@ -2,7 +2,6 @@ #pragma once -#include #include "array.h" #include "device.h" #include "dtype.h" @@ -25,7 +24,6 @@ bool is_same_shape(const std::vector& arrays); * https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_index.html */ int normalize_axis(int axis, int ndim); -std::vector normalize_axes(const std::vector& axes, int ndim); std::ostream& operator<<(std::ostream& os, const Device& d); std::ostream& operator<<(std::ostream& os, const Stream& s); @@ -43,18 +41,4 @@ inline std::ostream& operator<<(std::ostream& os, const float16_t& v) { inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { return os << static_cast(v); } - -using IntOrVec = std::variant>; -inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { - std::vector axes; - if (std::holds_alternative(v)) { - axes.resize(dims); - std::iota(axes.begin(), axes.end(), 0); - } else if (auto pv = std::get_if(&v); pv) { - axes.push_back(*pv); - } else { - axes = std::get>(v); - } - return axes; -} } // namespace mlx::core diff --git a/python/src/fft.cpp b/python/src/fft.cpp index 6b3739ae6..42ad37633 100644 --- a/python/src/fft.cpp +++ b/python/src/fft.cpp @@ -7,7 +7,6 @@ #include "mlx/fft.h" #include "mlx/ops.h" -#include "mlx/utils.h" namespace py = pybind11; using namespace py::literals; diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index de5ccfcf3..7bd186d51 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -26,193 +26,164 @@ using namespace mlx::core; using namespace mlx::core::linalg; void init_linalg(py::module_& parent_module) { + py::options options; + options.disable_function_signatures(); + auto m = parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra."); m.def( "norm", [](const array& a, - const std::variant& ord, - const std::variant>& axis, + const std::variant& ord_, + const std::variant>& axis_, const bool keepdims, const StreamOrDevice stream) { - return std::visit( - overloaded{ - [&](const double p) { - if (std::isinf((float)p) || std::isinf(p)) { - if (p > 0) { - return norm( - a, - "inf", - std::holds_alternative(axis) - ? std::vector() - : get_reduce_axes(axis, a.ndim()), - keepdims, - stream); - } - return norm( - a, - "-inf", - get_reduce_axes(axis, a.ndim()), - keepdims, - stream); - } - return norm( - a, - p, - std::holds_alternative(axis) - ? std::vector() - : get_reduce_axes(axis, a.ndim()), - keepdims, - stream); - }, - [&](const std::string& p) { - return norm( - a, - p, - std::holds_alternative(axis) - ? std::vector() - : get_reduce_axes(axis, a.ndim()), - keepdims, - stream); - }, - [&](const std::monostate _) { - return norm( - a, - std::holds_alternative(axis) - ? 
std::vector() - : get_reduce_axes(axis, a.ndim()), - keepdims, - stream); - }}, - ord); + std::optional> axis = std::nullopt; + if (auto pv = std::get_if(&axis_); pv) { + axis = std::vector{*pv}; + } else if (auto pv = std::get_if>(&axis_); pv) { + axis = *pv; + } + + if (std::holds_alternative(ord_)) { + return norm(a, axis, keepdims, stream); + } else { + if (auto pv = std::get_if(&ord_); pv) { + return norm(a, *pv, axis, keepdims, stream); + } + double ord; + if (auto pv = std::get_if(&ord_); pv) { + ord = *pv; + } else { + ord = std::get(ord_); + } + return norm(a, ord, axis, keepdims, stream); + } }, "a"_a, + py::pos_only(), "ord"_a = none, "axis"_a = none, "keepdims"_a = false, + py::kw_only(), "stream"_a = none, R"pbdoc( - Matrix or vector norm. + norm(a: array, /, ord: Union[None, scalar, str] = None, axis: Union[None, int, List[int]] = None, keepdims: bool = False, *, stream: Union[None, Stream, Device] = None) -> array - This function is able to return matrix or vector norms, - depending on the value of the ``ord`` parameter. + Matrix or vector norm. - Parameters - ---------- - a : array_like - Input array. If `axis` is None, `a` must be 1-D or 2-D, unless `ord` - is None. If both `axis` and `ord` are None, the 2-norm of ``a.flatten`` will be returned. - ord : {non-zero int, inf, -inf, 'fro', 'nuc'}, optional - Order of the norm (see table under ``Notes``). inf means float(`inf`) object. The default is None. - axis : {None, int, 2-tuple of ints}, optional. - If `axis` is an integer, it specifies the axis of `a` along which to - compute the vector norms. If `axis` is a 2-tuple, it specifies the - axes that hold 2-D matrices, and the matrix norms of these matrices - are computed. If `axis` is None then either a vector norm (when `a` - is 1-D) or a matrix norm (when `a` is 2-D) is returned. The default - is None. - keepdims : bool, optional - If this is set to True, the axes which are normed over are left in the - result as dimensions with size one. With this option the result will - broadcast correctly against the original `a`. + This function computes vector or matrix norms depending on the value of + the ``ord`` and ``axis`` parameters. - Returns - ------- - n : array - Norm of the matrix or vector(s). + Args: + a (array): Input array. If ``axis`` is ``None``, ``a`` must be 1-D or 2-D, + unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the + 2-norm of ``a.flatten`` will be returned. + ord (scalar or str, optional): Order of the norm (see table under ``Notes``). + If ``None``, the 2-norm will be computed along the given ``axis``. + Default: ``None``. + axis (int or list(int), optional): If ``axis`` is an integer, it specifies the + axis of ``a`` along which to compute the vector norms. If ``axis`` is a + 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix + norms of these matrices are computed. If `axis` is ``None`` then + either a vector norm (when ``a`` is 1-D) or a matrix norm (when ``a`` is + 2-D) is returned. Default: ``None``. + keepdims (bool, optional): If ``True``, the axes which are normed over are + left in the result as dimensions with size one. Default ``False``. - Notes - ----- - For values of ``ord < 1``, the result is, strictly speaking, not a - mathematical 'norm', but it may still be useful for various numerical - purposes. + Returns: + array: The output containing the norm(s). 
 
-        The following norms can be calculated:
+        Notes:
+          For values of ``ord < 1``, the result is, strictly speaking, not a
+          mathematical norm, but it may still be useful for various numerical
+          purposes.
 
-        ===== ============================ ==========================
-        ord    norm for matrices             norm for vectors
-        ===== ============================ ==========================
-        None   Frobenius norm                2-norm
-        'fro'  Frobenius norm                --
-        inf    max(sum(abs(x), axis=1))      max(abs(x))
-        -inf   min(sum(abs(x), axis=1))      min(abs(x))
-        0      --                            sum(x != 0)
-        1      max(sum(abs(x), axis=0))      as below
-        -1     min(sum(abs(x), axis=0))      as below
-        2      2-norm (largest sing. value)  as below
-        -2     smallest singular value       as below
-        other  --                            sum(abs(x)**ord)**(1./ord)
-        ===== ============================ ==========================
+          The following norms can be calculated:
 
-        Nuclear norm and norms based on singular values are not yet implemented.
+          ===== ============================ ==========================
+          ord    norm for matrices             norm for vectors
+          ===== ============================ ==========================
+          None   Frobenius norm                2-norm
+          'fro'  Frobenius norm                --
+          inf    max(sum(abs(x), axis=1))      max(abs(x))
+          -inf   min(sum(abs(x), axis=1))      min(abs(x))
+          0      --                            sum(x != 0)
+          1      max(sum(abs(x), axis=0))      as below
+          -1     min(sum(abs(x), axis=0))      as below
+          2      2-norm (largest sing. value)  as below
+          -2     smallest singular value       as below
+          other  --                            sum(abs(x)**ord)**(1./ord)
+          ===== ============================ ==========================
 
-        The Frobenius norm is given by [1]_:
+          .. warning::
+            Nuclear norm and norms based on singular values are not yet implemented.
 
-        :math:`||A||_F = [\\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
+          The Frobenius norm is given by [1]_:
 
-        The nuclear norm is the sum of the singular values.
+          :math:`||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}`
 
-        Both the Frobenius and nuclear norm orders are only defined for
-        matrices and raise a ValueError when ``a.ndim != 2``.
+          The nuclear norm is the sum of the singular values.
 
-        References
-        ----------
-        .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
-               Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
+          Both the Frobenius and nuclear norm orders are only defined for
+          matrices and raise a ``ValueError`` when ``a.ndim != 2``.
 
-        Examples
-        --------
-        >>> import mlx.core as mx
-        >>> from mlx.core import linalg as LA
-        >>> a = mx.arange(9) - 4
-        >>> a
-        array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
-        >>> b = a.reshape((3,3))
-        >>> b
-        array([[-4, -3, -2],
-               [-1,  0,  1],
-               [ 2,  3,  4]], dtype=int32)
-        >>> LA.norm(a)
-        array(7.74597, dtype=float32)
-        >>> LA.norm(b)
-        array(7.74597, dtype=float32)
-        >>> LA.norm(b, 'fro')
-        array(7.74597, dtype=float32)
-        >>> LA.norm(a, float("inf"))
-        array(4, dtype=int32)
-        >>> LA.norm(b, float("inf"))
-        array(9, dtype=int32)
-        >>> LA.norm(a, -float("inf"))
-        array(0, dtype=int32)
-        >>> LA.norm(b, -float("inf"))
-        array(2, dtype=int32)
-        >>> LA.norm(a, 1)
-        array(20, dtype=int32)
-        >>> LA.norm(b, 1)
-        array(7, dtype=int32)
-        >>> LA.norm(a, -1)
-        array(0, dtype=float32)
-        >>> LA.norm(b, -1)
-        array(6, dtype=int32)
-        >>> LA.norm(a, 2)
-        array(7.74597, dtype=float32)
-        >>> LA.norm(a, 3)
-        array(5.84804, dtype=float32)
-        >>> LA.norm(a, -3)
-        array(0, dtype=float32)
-        >>> c = mx.array([[ 1, 2, 3],
-        ...               [-1, 1, 4]])
-        >>> LA.norm(c, axis=0)
-        array([1.41421, 2.23607, 5], dtype=float32)
-        >>> LA.norm(c, axis=1)
-        array([3.74166, 4.24264], dtype=float32)
-        >>> LA.norm(c, ord=1, axis=1)
-        array([6, 6], dtype=int32)
-        >>> m = mx.arange(8).reshape(2,2,2)
-        >>> LA.norm(m, axis=(1,2))
-        array([3.74166, 11.225], dtype=float32)
-        >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
-        (array(3.74166, dtype=float32), array(11.225, dtype=float32))
-      )pbdoc");
+        References:
+          .. [1] G. H. Golub and C. F. Van Loan, *Matrix Computations*,
+             Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
+
+        Examples:
+          >>> import mlx.core as mx
+          >>> from mlx.core import linalg as la
+          >>> a = mx.arange(9) - 4
+          >>> a
+          array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
+          >>> b = a.reshape((3,3))
+          >>> b
+          array([[-4, -3, -2],
+                 [-1,  0,  1],
+                 [ 2,  3,  4]], dtype=int32)
+          >>> la.norm(a)
+          array(7.74597, dtype=float32)
+          >>> la.norm(b)
+          array(7.74597, dtype=float32)
+          >>> la.norm(b, 'fro')
+          array(7.74597, dtype=float32)
+          >>> la.norm(a, float("inf"))
+          array(4, dtype=float32)
+          >>> la.norm(b, float("inf"))
+          array(9, dtype=float32)
+          >>> la.norm(a, -float("inf"))
+          array(0, dtype=float32)
+          >>> la.norm(b, -float("inf"))
+          array(2, dtype=float32)
+          >>> la.norm(a, 1)
+          array(20, dtype=float32)
+          >>> la.norm(b, 1)
+          array(7, dtype=float32)
+          >>> la.norm(a, -1)
+          array(0, dtype=float32)
+          >>> la.norm(b, -1)
+          array(6, dtype=float32)
+          >>> la.norm(a, 2)
+          array(7.74597, dtype=float32)
+          >>> la.norm(a, 3)
+          array(5.84804, dtype=float32)
+          >>> la.norm(a, -3)
+          array(0, dtype=float32)
+          >>> c = mx.array([[ 1, 2, 3],
+          ...               [-1, 1, 4]])
+          >>> la.norm(c, axis=0)
+          array([1.41421, 2.23607, 5], dtype=float32)
+          >>> la.norm(c, axis=1)
+          array([3.74166, 4.24264], dtype=float32)
+          >>> la.norm(c, ord=1, axis=1)
+          array([6, 6], dtype=float32)
+          >>> m = mx.arange(8).reshape(2,2,2)
+          >>> la.norm(m, axis=(1,2))
+          array([3.74166, 11.225], dtype=float32)
+          >>> la.norm(m[0, :, :]), la.norm(m[1, :, :])
+          (array(3.74166, dtype=float32), array(11.225, dtype=float32))
+      )pbdoc");
 }
diff --git a/python/src/utils.h b/python/src/utils.h
index 9751b2d6e..5ac878979 100644
--- a/python/src/utils.h
+++ b/python/src/utils.h
@@ -1,6 +1,7 @@
 // Copyright © 2023 Apple Inc.
 
 #pragma once
 
+#include <numeric>
 #include <variant>
 #include <vector>
 
@@ -13,10 +14,24 @@
 namespace py = pybind11;
 using namespace mlx::core;
 
+using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;
 using ScalarOrArray = std::
     variant<py::bool_, py::int_, py::float_, std::complex<float>, py::object>;
+
 static constexpr std::monostate none{};
 
+inline std::vector<int> get_reduce_axes(const IntOrVec& v, int dims) {
+  std::vector<int> axes;
+  if (std::holds_alternative<std::monostate>(v)) {
+    axes.resize(dims);
+    std::iota(axes.begin(), axes.end(), 0);
+  } else if (auto pv = std::get_if<int>(&v); pv) {
+    axes.push_back(*pv);
+  } else {
+    axes = std::get<std::vector<int>>(v);
+  }
+  return axes;
+}
+
 inline array to_array_with_accessor(py::object obj) {
   if (py::hasattr(obj, "__mlx_array__")) {
     return obj.attr("__mlx_array__")().cast<array>();
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index 9841f03bf..1d8ee43d9 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -3,170 +3,155 @@
 #include "doctest/doctest.h"
 
 #include <cmath>
-#include <iostream>
-#include "mlx/linalg.h"
+
 #include "mlx/mlx.h"
 
 using namespace mlx::core;
 using namespace mlx::core::linalg;
 
 TEST_CASE("[mlx.core.linalg.norm] no ord") {
-  array arr_one_d({1, 2, 3});
-  array arr_two_d = reshape(arange(9), {3, 3});
-  array arr_three_d = reshape(arange(18), {2, 3, 3});
+  // Zero dimensions
+  array x(2.0);
+  CHECK_EQ(norm(x).item<float>(), 2.0f);
+  CHECK_THROWS(norm(x, 0));
 
-  CHECK(array_equal(norm(arr_one_d), array(sqrt(1 + 4 + 9))).item<bool>());
-  CHECK(array_equal(norm(arr_one_d, {0}, false), array(sqrt(1 + 4 + 9)))
-            .item<bool>());
+  x = array({1, 2, 3});
+  float expected = std::sqrt(1 + 4 + 9);
+  CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
+  CHECK_EQ(norm(x, 0, false).item<float>(), doctest::Approx(expected));
+  CHECK_EQ(norm(x, -1, false).item<float>(), doctest::Approx(expected));
+  CHECK_EQ(norm(x, -1, true).ndim(), 1);
+  CHECK_THROWS(norm(x, 1));
+
+  x = reshape(arange(9), {3, 3});
+  expected =
+      std::sqrt(0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8);
+
+  CHECK_EQ(norm(x).item<float>(), doctest::Approx(expected));
+  CHECK_EQ(
+      norm(x, std::vector<int>{0, 1}).item<float>(), doctest::Approx(expected));
   CHECK(array_equal(
-            norm(arr_two_d, {}, false),
-            array(sqrt(
-                0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
-            .item<bool>());
-  CHECK(array_equal(
-            norm(arr_two_d, {0}, false),
+            norm(x, 0, false),
             array(
-                {sqrt(0 + 3 * 3 + 6 * 6),
-                 sqrt(1 + 4 * 4 + 7 * 7),
-                 sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
+                {std::sqrt(0 + 3 * 3 + 6 * 6),
+                 std::sqrt(1 + 4 * 4 + 7 * 7),
+                 std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
             .item<bool>());
-  CHECK(array_equal(
-            norm(arr_two_d, {1}, false),
+  CHECK(allclose(
+            norm(x, 1, false),
             array(
-                {sqrt(0 + 1 + 2 * 2),
-                 sqrt(3 * 3 + 4 * 4 + 5 * 5),
-                 sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
+                {std::sqrt(0 + 1 + 2 * 2),
+                 std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
+                 std::sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
             .item<bool>());
-  CHECK(array_equal(
-            norm(arr_two_d, {0, 1}, false),
-            array(sqrt(
-                0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 + 8 * 8)))
-            .item<bool>());
-  CHECK(array_equal(
-            norm(arr_three_d, {2}, false),
+
+  x = reshape(arange(18), {2, 3, 3});
+  CHECK(allclose(
+            norm(x, 2, false),
             array(
                 {
-                    sqrt(0 + 1 + 2 * 2),
-                    sqrt(3 * 3 + 4 * 4 + 5 * 5),
-                    sqrt(6 * 6 + 7 * 7 + 8 * 8),
-                    sqrt(9 * 9 + 10 * 10 + 11 * 11),
-                    sqrt(12 * 12 + 13 * 13 + 14 * 14),
-                    sqrt(15 * 15 + 16 * 16 + 17 * 17),
+                    std::sqrt(0 + 1 + 2 * 2),
+                    std::sqrt(3 * 3 + 4 * 4 + 5 * 5),
+                    std::sqrt(6 * 6 + 7 * 7 + 8 * 8),
+                    std::sqrt(9 * 9 + 10 * 10 + 11 * 11),
+                    std::sqrt(12 * 12 + 13 * 13 + 14 * 14),
+                    std::sqrt(15 * 15 + 16 * 16 + 17 * 17),
                 },
                 {2, 3}))
            .item<bool>());
-  CHECK(array_equal(
-            norm(arr_three_d, {1}, false),
+  CHECK(allclose(
+            norm(x, std::vector<int>{1, 2}, false),
             array(
-                {
-                    sqrt(0 + 3 * 3 + 6 * 6),
-                    sqrt(1 + 4 * 4 + 7 * 7),
-                    sqrt(2 * 2 + 5 * 5 + 8 * 8),
-                    sqrt(9 * 9 + 12 * 12 + 15 * 15),
-                    sqrt(10 * 10 + 13 * 13 + 16 * 16),
-                    sqrt(11 * 11 + 14 * 14 + 17 * 17),
-                },
-                {2, 3}))
-            .item<bool>());
-  CHECK(array_equal(
-            norm(arr_three_d, {0}, false),
-            array(
-                {
-                    sqrt(0 + 9 * 9),
-                    sqrt(1 + 10 * 10),
-                    sqrt(2 * 2 + 11 * 11),
-                    sqrt(3 * 3 + 12 * 12),
-                    sqrt(4 * 4 + 13 * 13),
-                    sqrt(5 * 5 + 14 * 14),
-                    sqrt(6 * 6 + 15 * 15),
-                    sqrt(7 * 7 + 16 * 16),
-                    sqrt(8 * 8 + 17 * 17),
-                },
-                {3, 3}))
-            .item<bool>());
-  CHECK(array_equal(
-            norm(arr_three_d, {1, 2}, false),
-            array(
-                {sqrt(
+                {std::sqrt(
                      0 + 1 + 2 * 2 + 3 * 3 + 4 * 4 + 5 * 5 + 6 * 6 + 7 * 7 +
                      8 * 8),
-                 sqrt(
+                 std::sqrt(
                      9 * 9 + 10 * 10 + 11 * 11 + 12 * 12 + 13 * 13 + 14 * 14 +
                      15 * 15 + 16 * 16 + 17 * 17)},
                 {2}))
             .item<bool>());
+  CHECK_THROWS(norm(x, std::vector<int>{0, 1, 2}));
 }
 
 TEST_CASE("[mlx.core.linalg.norm] double ord") {
-  array arr_one_d({1, 2, 3});
-  array arr_two_d = reshape(arange(9), {3, 3});
-  array arr_three_d = reshape(arange(18), {2, 3, 3});
+  CHECK_THROWS(norm(array(0), 2.0));
 
-  CHECK(array_equal(norm(arr_one_d, 2.0), array(sqrt(1 + 4 + 9))).item<bool>());
-  CHECK(array_equal(norm(arr_one_d, 1.0), array(1 + 2 + 3)).item<bool>());
-  CHECK(array_equal(norm(arr_one_d, 0.0), array(3)).item<bool>());
+  array x({1, 2, 3});
 
-  CHECK(array_equal(norm(arr_one_d, 2.0, {0}, false), array(sqrt(1 + 4 + 9)))
-            .item<bool>());
-  CHECK(array_equal(
-            norm(arr_two_d, 2.0, {0}, false),
+  float expected = std::sqrt(1 + 4 + 9);
+  CHECK_EQ(norm(x, 2.0).item<float>(), doctest::Approx(expected));
+  CHECK_EQ(norm(x, 2.0, 0).item<float>(), doctest::Approx(expected));
+  CHECK_THROWS(norm(x, 2.0, 1));
+
+  expected = 1 + 2 + 3;
+  CHECK_EQ(norm(x, 1.0).item<float>(), doctest::Approx(expected));
+
+  expected = 3;
+  CHECK_EQ(norm(x, 0.0).item<float>(), doctest::Approx(expected));
+
+  expected = 3;
+  CHECK_EQ(
+      norm(x, std::numeric_limits<double>::infinity()).item<float>(),
+      doctest::Approx(expected));
+
+  expected = 1;
+  CHECK_EQ(
+      norm(x, -std::numeric_limits<double>::infinity()).item<float>(),
+      doctest::Approx(expected));
+
+  x = reshape(arange(9), {3, 3});
+
+  CHECK(allclose(
+            norm(x, 2.0, 0, false),
             array(
-                {sqrt(0 + 3 * 3 + 6 * 6),
-                 sqrt(1 + 4 * 4 + 7 * 7),
-                 sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
+                {std::sqrt(0 + 3 * 3 + 6 * 6),
+                 std::sqrt(1 + 4 * 4 + 7 * 7),
+                 std::sqrt(2 * 2 + 5 * 5 + 8 * 8)}))
             .item<bool>());
-  CHECK(array_equal(
-            norm(arr_two_d, 2.0, {1}, false),
+  CHECK(allclose(
+            norm(x, 2.0, 1, false),
             array(
                 {sqrt(0 + 1 + 2 * 2),
                  sqrt(3 * 3 + 4 * 4 + 5 * 5),
                  sqrt(6 * 6 + 7 * 7 + 8 * 8)}))
             .item<bool>());
-  CHECK(array_equal(
-            norm(arr_three_d, 2.0, {2}, false),
-            array(
-                {
-                    sqrt(0 + 1 + 2 * 2),
-                    sqrt(3 * 3 + 4 * 4 + 5 * 5),
-                    sqrt(6 * 6 + 7 * 7 + 8 * 8),
-                    sqrt(9 * 9 + 10 * 10 + 11 * 11),
-                    sqrt(12 * 12 + 13 * 13 + 14 * 14),
-                    sqrt(15 * 15 + 16 * 16 + 17 * 17),
-                },
-                {2, 3}))
-            .item<bool>());
-  CHECK(array_equal(
-            norm(arr_three_d, 2.0, {1}, false),
-            array(
-                {
-                    sqrt(0 + 3 * 3 + 6 * 6),
-                    sqrt(1 + 4 * 4 + 7 * 7),
-                    sqrt(2 * 2 + 5 * 5 + 8 * 8),
-                    sqrt(9 * 9 + 12 * 12 + 15 * 15),
-                    sqrt(10 * 10 + 13 * 13 + 16 * 16),
-                    sqrt(11 * 11 + 14 * 14 + 17 * 17),
-                },
-                {2, 3}))
-            .item<bool>());
-  CHECK(array_equal(
-            norm(arr_three_d, 2.0, {0}, false),
-            array(
-                {
-                    sqrt(0 + 9 * 9),
-                    sqrt(1 + 10 * 10),
-                    sqrt(2 * 2 + 11 * 11),
-                    sqrt(3 * 3 + 12 * 12),
-                    sqrt(4 * 4 + 13 * 13),
-                    sqrt(5 * 5 + 14 * 14),
-                    sqrt(6 * 6 + 15 * 15),
-                    sqrt(7 * 7 + 16 * 16),
-                    sqrt(8 * 8 + 17 * 17),
-                },
-                {3, 3}))
-            .item<bool>());
+  CHECK_EQ(
+      norm(x, 1.0, std::vector<int>{0, 1}).item<float>(),
+      doctest::Approx(15.0));
+  CHECK_EQ(
+      norm(x, 1.0, std::vector<int>{1, 0}).item<float>(),
+      doctest::Approx(21.0));
+  CHECK_EQ(
+      norm(x, -1.0, std::vector<int>{0, 1}).item<float>(),
+      doctest::Approx(9.0));
+  CHECK_EQ(
+      norm(x, -1.0, std::vector<int>{1, 0}).item<float>(),
+      doctest::Approx(3.0));
+  CHECK_EQ(
+      norm(x, 1.0, std::vector<int>{0, 1}, true).shape(),
+      std::vector<int>{1, 1});
+  CHECK_EQ(
+      norm(x, 1.0, std::vector<int>{1, 0}, true).shape(),
+      std::vector<int>{1, 1});
+  CHECK_EQ(
+      norm(x, -1.0, std::vector<int>{0, 1}, true).shape(),
+      std::vector<int>{1, 1});
+  CHECK_EQ(
+      norm(x, -1.0, std::vector<int>{1, 0}, true).shape(),
+      std::vector<int>{1, 1});
+
+  CHECK_EQ(
+      norm(x, -1.0, std::vector<int>{-2, -1}, false).item<float>(),
+      doctest::Approx(9.0));
+  CHECK_EQ(
+      norm(x, 1.0, std::vector<int>{-2, -1}, false).item<float>(),
+      doctest::Approx(15.0));
+
+  x = reshape(arange(18), {2, 3, 3});
+  CHECK_THROWS(norm(x, 2.0, std::vector<int>{0, 1, 2}));
 
   CHECK(allclose(
-            norm(arr_three_d, 3.0, {0}),
+            norm(x, 3.0, 0),
             array(
                 {9.,
                  10.00333222,
                  11.02620846,
                  12.06052589,
                  13.10242264,
                  14.14920485,
                  15.19934531,
                  16.25160839,
                  17.57113899},
                 {3, 3}))
            .item<bool>());
-  CHECK(
-      allclose(
-          norm(arr_three_d, 3.0, {1}),
-          array(
-              {6.24025147, 7.41685954, 8.6401226, 18., 19.39257164, 20.7915893},
-              {2, 3}))
-          .item<bool>());
   CHECK(allclose(
-            norm(arr_three_d, 3.0, {2}),
+            norm(x, 3.0, 2),
             array(
                 {2.08008382,
                  6.,
                  10.23127655,
                  14.5180117,
                  18.82291607,
                  23.13593104},
                 {2, 3}))
            .item<bool>());
-  CHECK(allclose(
-            norm(arr_three_d, 0.0, {0}),
-            array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
+  CHECK(
+      allclose(
+          norm(x, 0.0, 0), array({1., 2., 2., 2., 2., 2., 2., 2., 2.}, {3, 3}))
           .item<bool>());
-  CHECK(
-      allclose(
-          norm(arr_three_d, 0.0, {1}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
-          .item<bool>());
-  CHECK(
-      allclose(
-          norm(arr_three_d, 0.0, {2}), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
-          .item<bool>());
+  CHECK(allclose(norm(x, 0.0, 1), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
+            .item<bool>());
+  CHECK(allclose(norm(x, 0.0, 2), array({2., 3., 3., 3., 3., 3.}, {2, 3}))
+            .item<bool>());
   CHECK(allclose(
-            norm(arr_three_d, 1.0, {0}),
+            norm(x, 1.0, 0),
             array({9., 11., 13., 15., 17., 19., 21., 23., 25.}, {3, 3}))
             .item<bool>());
-  CHECK(allclose(
-            norm(arr_three_d, 1.0, {1}),
-            array({9., 12., 15., 36., 39., 42.}, {2, 3}))
+  CHECK(allclose(norm(x, 1.0, 1), array({9., 12., 15., 36., 39., 42.}, {2, 3}))
             .item<bool>());
-  CHECK(allclose(
-            norm(arr_three_d, 1.0, {2}),
-            array({3., 12., 21., 30., 39., 48.}, {2, 3}))
+  CHECK(allclose(norm(x, 1.0, 2), array({3., 12., 21., 30., 39., 48.}, {2, 3}))
             .item<bool>());
-  CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}), array({15.0})).item<bool>());
-  CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}), array({21.0})).item<bool>());
-  CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}), array({9.0})).item<bool>());
-  CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}), array({3.0})).item<bool>());
-
-  CHECK(allclose(norm(arr_two_d, 1.0, {0, 1}, true), array({15.0}, {1, 1}))
+  CHECK(allclose(norm(x, 1.0, std::vector<int>{0, 1}), array({21., 23., 25.}))
             .item<bool>());
-  CHECK(allclose(norm(arr_two_d, 1.0, {1, 0}, true), array({21.0}, {1, 1}))
+  CHECK(allclose(norm(x, 1.0, std::vector<int>{1, 2}), array({15., 42.}))
             .item<bool>());
-  CHECK(allclose(norm(arr_two_d, -1.0, {0, 1}, true), array({9.0}, {1, 1}))
+  CHECK(allclose(norm(x, -1.0, std::vector<int>{0, 1}), array({9., 11., 13.}))
             .item<bool>());
-  CHECK(allclose(norm(arr_two_d, -1.0, {1, 0}, true), array({3.0}, {1, 1}))
+  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9., 36.}))
             .item<bool>());
-
-  CHECK(array_equal(norm(arr_two_d, -1.0, {-2, -1}, false), array(9.0))
+  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 0}), array({9., 12., 15.}))
             .item<bool>());
-  CHECK(array_equal(norm(arr_two_d, 1.0, {-2, -1}, false), array(15.0))
+  CHECK(allclose(norm(x, -1.0, std::vector<int>{2, 1}), array({3, 30}))
             .item<bool>());
-  //
-  CHECK(allclose(norm(arr_three_d, 1.0, {0, 1}), array({21., 23., 25.}))
+  CHECK(allclose(norm(x, -1.0, std::vector<int>{1, 2}), array({9, 36}))
             .item<bool>());
-  CHECK(
-      allclose(norm(arr_three_d, 1.0, {1, 2}), array({15., 42.})).item<bool>());
-  CHECK(allclose(norm(arr_three_d, -1.0, {0, 1}), array({9., 11., 13.}))
-            .item<bool>());
-  CHECK(
-      allclose(norm(arr_three_d, -1.0, {1, 2}), array({9., 36.})).item<bool>());
-  CHECK(allclose(norm(arr_three_d, -1.0, {1, 0}), array({9., 12., 15.}))
-            .item<bool>());
-  CHECK(allclose(norm(arr_three_d, -1.0, {2, 1}), array({3, 30})).item<bool>());
-  CHECK(allclose(norm(arr_three_d, -1.0, {1, 2}), array({9, 36})).item<bool>());
 }
 
 TEST_CASE("[mlx.core.linalg.norm] string ord") {
-  array arr_one_d({1, 2, 3});
-  array arr_two_d = reshape(arange(9), {3, 3});
-  array arr_three_d = reshape(arange(18), {2, 3, 3});
+  array x({1, 2, 3});
+  CHECK_THROWS(norm(x, "fro"));
 
-  CHECK(allclose(norm(arr_one_d, "inf", {}), array({3.0})).item<bool>());
-  CHECK(allclose(norm(arr_one_d, "-inf", {}), array({1.0})).item<bool>());
+  x = reshape(arange(9), {3, 3});
+  CHECK_THROWS(norm(x, "bad ord"));
 
-  CHECK(allclose(norm(arr_two_d, "f", {0, 1}), array({14.2828568570857}))
-            .item<bool>());
-  CHECK(allclose(norm(arr_two_d, "fro", {0, 1}), array({14.2828568570857}))
-            .item<bool>());
-  CHECK(allclose(norm(arr_two_d, "inf", {0, 1}), array({21.0})).item<bool>());
-  CHECK(allclose(norm(arr_two_d, "-inf", {0, 1}), array({3.0})).item<bool>());
+  CHECK_EQ(
+      norm(x, "f", std::vector<int>{0, 1}).item<float>(),
+      doctest::Approx(14.2828568570857));
+  CHECK_EQ(
+      norm(x, "fro", std::vector<int>{0, 1}).item<float>(),
+      doctest::Approx(14.2828568570857));
 
+  x = reshape(arange(18), {2, 3, 3});
   CHECK(allclose(
-            norm(arr_three_d, "fro", {0, 1}),
+            norm(x, "fro", std::vector<int>{0, 1}),
             array({22.24859546, 24.31049156, 26.43860813}))
             .item<bool>());
   CHECK(allclose(
-            norm(arr_three_d, "fro", {1, 2}), array({14.28285686, 39.7617907}))
+            norm(x, "fro", std::vector<int>{1, 2}),
+            array({14.28285686, 39.7617907}))
             .item<bool>());
   CHECK(allclose(
-            norm(arr_three_d, "f", {0, 1}),
+            norm(x, "f", std::vector<int>{0, 1}),
             array({22.24859546, 24.31049156, 26.43860813}))
             .item<bool>());
   CHECK(allclose(
-            norm(arr_three_d, "f", {1, 0}),
+            norm(x, "f", std::vector<int>{1, 0}),
             array({22.24859546, 24.31049156, 26.43860813}))
             .item<bool>());
-  CHECK(
-      allclose(norm(arr_three_d, "f", {1, 2}), array({14.28285686, 39.7617907}))
-          .item<bool>());
-  CHECK(
-      allclose(norm(arr_three_d, "f", {2, 1}), array({14.28285686, 39.7617907}))
-          .item<bool>());
-  CHECK(allclose(norm(arr_three_d, "inf", {0, 1}), array({36., 39., 42.}))
+  CHECK(allclose(
+            norm(x, "f", std::vector<int>{1, 2}),
+            array({14.28285686, 39.7617907}))
             .item<bool>());
-  CHECK(allclose(norm(arr_three_d, "inf", {1, 2}), array({21., 48.}))
+  CHECK(allclose(
+            norm(x, "f", std::vector<int>{2, 1}),
+            array({14.28285686, 39.7617907}))
             .item<bool>());
-  CHECK(allclose(norm(arr_three_d, "-inf", {0, 1}), array({9., 12., 15.}))
-            .item<bool>());
-  CHECK(allclose(norm(arr_three_d, "-inf", {1, 2}), array({3., 30.}))
-            .item<bool>());
-  CHECK(allclose(norm(arr_three_d, "-inf", {1, 0}), array({9., 11., 13.}))
-            .item<bool>());
-  CHECK(allclose(norm(arr_three_d, "-inf", {2, 1}), array({9., 36.}))
-            .item<bool>());
-}
\ No newline at end of file
+}
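
Note on the binding change in the patch above: the new norm binding drops std::visit in favor of a chain of std::get_if probes over a std::variant, with std::monostate standing in for Python's None default. The following is a minimal standalone sketch of that pattern (illustrative only, not MLX code; the names to_axes and IntOrVec here mirror the patch but the program itself is hypothetical):

// Sketch: variant-based "None | int | list[int]" argument handling.
#include <iostream>
#include <optional>
#include <variant>
#include <vector>

using IntOrVec = std::variant<std::monostate, int, std::vector<int>>;

std::optional<std::vector<int>> to_axes(const IntOrVec& v) {
  if (auto pv = std::get_if<int>(&v); pv) {
    return std::vector<int>{*pv}; // a single axis
  } else if (auto pv = std::get_if<std::vector<int>>(&v); pv) {
    return *pv; // an explicit list of axes
  }
  return std::nullopt; // monostate: the caller passed None
}

int main() {
  for (const IntOrVec& v :
       {IntOrVec{}, IntOrVec{1}, IntOrVec{std::vector<int>{0, 1}}}) {
    auto axes = to_axes(v);
    std::cout << (axes ? axes->size() : 0) << " axis/axes requested\n";
  }
}

One upshot of this style is that each alternative gets an early return, so the pybind11 lambda stays flat instead of nesting lambdas inside a visitor.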
From 49e3e99da350594af2a92103ee1fb4dd5d2179df Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 26 Dec 2023 10:58:02 -0800
Subject: [PATCH 34/37] remove unused includes

---
 python/src/linalg.cpp   | 9 ---------
 python/src/overloaded.h | 8 --------
 2 files changed, 17 deletions(-)
 delete mode 100644 python/src/overloaded.h

diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 7bd186d51..c193060db 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -1,22 +1,13 @@
-
 // Copyright © 2023 Apple Inc.
 
-#include <cmath>
-#include <iostream>
-#include <limits>
-#include <numeric>
 #include <pybind11/complex.h>
-#include <pybind11/numpy.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
 
 #include "mlx/linalg.h"
-#include "mlx/ops.h"
-#include "mlx/utils.h"
 
 #include "python/src/load.h"
-#include "python/src/overloaded.h"
 #include "python/src/utils.h"
 
 namespace py = pybind11;
diff --git a/python/src/overloaded.h b/python/src/overloaded.h
deleted file mode 100644
index 204a3178a..000000000
--- a/python/src/overloaded.h
+++ /dev/null
@@ -1,8 +0,0 @@
-// Copyright © 2023 Apple Inc.
-
-template <typename... Ts>
-struct overloaded : Ts... {
-  using Ts::operator()...;
-};
-template <typename... Ts>
-overloaded(Ts...) -> overloaded<Ts...>;
\ No newline at end of file
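
For reference, the header deleted above implemented the standard C++17 "overloaded" visitor idiom: a struct that inherits the call operator of every lambda passed to it, plus a deduction guide so the template arguments are inferred. Before this series the norm binding paired it with std::visit. A standalone sketch of how it is used (illustrative only, not MLX code):

// Sketch: the overloaded + std::visit idiom that overloaded.h enabled.
#include <iostream>
#include <string>
#include <variant>

template <typename... Ts>
struct overloaded : Ts... {
  using Ts::operator()...; // pull in every lambda's operator()
};
template <typename... Ts>
overloaded(Ts...) -> overloaded<Ts...>; // deduction guide: overloaded{f, g} works

int main() {
  std::variant<std::monostate, double, std::string> ord = 2.0;
  std::visit(
      overloaded{
          [](std::monostate) { std::cout << "default 2-norm\n"; },
          [](double p) { std::cout << "p-norm, p = " << p << "\n"; },
          [](const std::string& s) { std::cout << "named norm: " << s << "\n"; },
      },
      ord);
}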
From 67e319488cefd339a9c7e74b001672a90485e13b Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 26 Dec 2023 14:47:56 -0800
Subject: [PATCH 35/37] fix python tests

---
 mlx/linalg.cpp              | 29 ++++++++-----
 mlx/linalg.h                | 17 ++++----
 python/src/linalg.cpp       |  8 ++--
 python/tests/test_linalg.py | 84 +++++++++++++++----------------------
 4 files changed, 65 insertions(+), 73 deletions(-)

diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index 61c9e8537..9cce6cabb 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -4,9 +4,7 @@
 #include <numeric>
 #include <sstream>
 
-#include "mlx/array.h"
 #include "mlx/linalg.h"
-#include "mlx/ops.h"
 
 namespace mlx::core::linalg {
 
@@ -48,25 +46,36 @@ inline array matrix_norm(
   auto dtype = at_least_float(a.dtype());
   auto row_axis = axis[0];
   auto col_axis = axis[1];
-  if (!keepdims && col_axis > row_axis && col_axis > 0) {
-    col_axis -= 1;
-  }
   if (ord == -1.0) {
+    col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
     return astype(
         min(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
         dtype,
         s);
   } else if (ord == 1.0) {
+    col_axis -= (!keepdims && col_axis > row_axis && col_axis > 0);
     return astype(
         max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s),
         dtype,
         s);
+  } else if (ord == std::numeric_limits<double>::infinity()) {
+    row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
+    return astype(
+        max(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
+        dtype,
+        s);
+  } else if (ord == -std::numeric_limits<double>::infinity()) {
+    row_axis -= (!keepdims && row_axis > col_axis && row_axis > 0);
+    return astype(
+        min(sum(abs(a, s), col_axis, keepdims, s), row_axis, keepdims, s),
+        dtype,
+        s);
   } else if (ord == 2.0 || ord == -2.0) {
     throw std::runtime_error(
         "[linalg::norm] Singular value norms are not implemented.");
   } else {
     std::ostringstream msg;
-    msg << "[linalg::norm] Invalid ord value " << ord << " for matrix norm";
+    msg << "[linalg::norm] Invalid ord " << ord << " for matrix norm.";
     throw std::invalid_argument(msg.str());
   }
 }
@@ -78,13 +87,13 @@ inline array matrix_norm(
     bool keepdims,
     StreamOrDevice s) {
   if (ord == "f" || ord == "fro") {
-    return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s));
+    return sqrt(sum(square(a, s), axis, keepdims, s), s);
   } else if (ord == "nuc") {
     throw std::runtime_error(
         "[linalg::norm] Nuclear norm not yet implemented.");
   } else {
     std::ostringstream msg;
-    msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm";
+    msg << "[linalg::norm] Invalid ord value '" << ord << "' for matrix norm.";
     throw std::invalid_argument(msg.str());
   }
 }
@@ -100,7 +109,7 @@ array norm(
   if (axis.value().size() > 2) {
     throw std::invalid_argument(
-        "[linalg::norm] Received too many axes for norm");
+        "[linalg::norm] Received too many axes for norm.");
   }
   return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
 }
@@ -124,7 +133,7 @@ array norm(
     return matrix_norm(a, ord, ax, keepdims, s);
   } else {
     throw std::invalid_argument(
-        "[linalg::norm] Received too many axes for norm");
+        "[linalg::norm] Received too many axes for norm.");
   }
 }
diff --git a/mlx/linalg.h b/mlx/linalg.h
index bf3b5e78c..80e484eb5 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -4,21 +4,22 @@
 
 #include <optional>
 
-#include "array.h"
-#include "device.h"
-#include "ops.h"
-#include "stream.h"
+#include "mlx/array.h"
+#include "mlx/device.h"
+#include "mlx/ops.h"
+#include "mlx/stream.h"
 
 namespace mlx::core::linalg {
 
-/*
+/**
  * Compute vector or matrix norms.
  *
  * - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
  * - If axis is not provided but ord is, then x must be either 1D or 2D.
- * - If axis is provided, but ord is not, then the 2-norm is computed along the
- *   given axes. At most 2 axes can be specified.
- * - If both axis and ord are provided, then the corresponding matrix of vector
+ * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
+ *   for matrices) is computed along the given axes. At most 2 axes can be
+ *   specified.
+ * - If both axis and ord are provided, then the corresponding matrix or vector
 *   norm is computed. At most 2 axes can be specified.
 */
 array norm(
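
To make the four dispatch cases documented in mlx/linalg.h above concrete, here is a small standalone program against the C++ API. It is a sketch that assumes an MLX build already containing this patch series; the values in the comments are approximate:

// Sketch: the four norm dispatch cases, via mlx::core::linalg::norm.
#include <iostream>
#include <vector>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array m = reshape(arange(4), {2, 2}); // [[0, 1], [2, 3]]
  // No axis and no ord: 2-norm of the flattened array, sqrt(0+1+4+9) ~ 3.742.
  std::cout << linalg::norm(m).item<float>() << "\n";
  // ord only: m must be 1-D or 2-D; ord=1 on a matrix is the max column sum (4).
  std::cout << linalg::norm(m, 1.0).item<float>() << "\n";
  // axis only: vector 2-norms along axis 0, i.e. {2, sqrt(10)}.
  std::cout << linalg::norm(m, std::vector<int>{0}) << "\n";
  // Both ord and axis: Frobenius matrix norm over axes (0, 1), same as line one.
  std::cout << linalg::norm(m, "fro", std::vector<int>{0, 1}).item<float>() << "\n";
}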
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index c193060db..ea5474a70 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -20,8 +20,8 @@ void init_linalg(py::module_& parent_module) {
   py::options options;
   options.disable_function_signatures();
 
-  auto m =
-      parent_module.def_submodule("linalg", "mlx.core.linalg: Linear Algebra.");
+  auto m = parent_module.def_submodule(
+      "linalg", "mlx.core.linalg: linear algebra routines.");
 
   m.def(
       "norm",
@@ -72,8 +72,8 @@
             unless ``ord`` is ``None``. If both ``axis`` and ``ord`` are ``None``, the
             2-norm of ``a.flatten`` will be returned.
           ord (scalar or str, optional): Order of the norm (see table under ``Notes``).
-            If ``None``, the 2-norm will be computed along the given ``axis``.
-            Default: ``None``.
+            If ``None``, the 2-norm (or Frobenius norm for matrices) will be computed
+            along the given ``axis``. Default: ``None``.
           axis (int or list(int), optional): If ``axis`` is an integer, it specifies the
             axis of ``a`` along which to compute the vector norms. If ``axis`` is a
             2-tuple, it specifies the axes that hold 2-D matrices, and the matrix
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index ce1926de0..08a4510c8 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -11,74 +11,56 @@ import numpy as np
 
 class TestLinalg(mlx_tests.MLXTestCase):
     def test_norm(self):
-        vector_ords = [None, 0.5, 0, 1, 2, 3, -1, 1, float("inf"), -float("inf")]
+        vector_ords = [None, 0.5, 0, 1, 2, 3, -1, float("inf"), -float("inf")]
         matrix_ords = [None, "fro", -1, 1, float("inf"), -float("inf")]
 
         for shape in [(3,), (2, 3), (2, 3, 3)]:
-            x_mx = mx.arange(math.prod(shape)).reshape(shape)
-            x_np = np.arange(math.prod(shape)).reshape(shape)
+            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
+            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
             # Test when at least one axis is provided
             for num_axes in range(1, len(shape)):
-                for axis in itertools.combinations(range(len(shape)), num_axes):
-                    if num_axes == 1:
-                        ords = vector_ords
-                    else:
-                        ords = matrix_ords
-                    for keepdims in [True, False]:
-                        # Test axis provided, no ord provided
-                        out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
-                        out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
-                        assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
-                    # Test both ord and axis provided
-                    for o in ords:
-                        for keepdims in [True, False]:
-                            if o:
-                                out_np = np.linalg.norm(
-                                    x_np, ord=o, axis=axis, keepdims=keepdims
-                                )
-                                out_mx = mx.linalg.norm(
-                                    x_mx, ord=o, axis=axis, keepdims=keepdims
-                                )
-                            else:
-                                out_np = np.linalg.norm(
-                                    x_np, axis=axis, keepdims=keepdims
-                                )
-                                out_mx = mx.linalg.norm(
-                                    x_mx, axis=axis, keepdims=keepdims
-                                )
-                            assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
-
-        # Test only axis provided
-        for shape in [(3,), (2, 3), (2, 3, 3)]:
-            x_mx = mx.arange(math.prod(shape)).reshape(shape)
-            x_np = np.arange(math.prod(shape)).reshape(shape)
-
-            for num_axes in range(1, len(shape)):
+                if num_axes == 1:
+                    ords = vector_ords
+                else:
+                    ords = matrix_ords
                 for axis in itertools.combinations(range(len(shape)), num_axes):
                     for keepdims in [True, False]:
-                        out_np = np.linalg.norm(x_np, axis=axis, keepdims=keepdims)
-                        out_mx = mx.linalg.norm(x_mx, axis=axis, keepdims=keepdims)
-                        assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+                        for o in ords:
+                            out_np = np.linalg.norm(
+                                x_np, ord=o, axis=axis, keepdims=keepdims
+                            )
+                            out_mx = mx.linalg.norm(
+                                x_mx, ord=o, axis=axis, keepdims=keepdims
+                            )
+                            with self.subTest(
+                                shape=shape, ord=o, axis=axis, keepdims=keepdims
+                            ):
+                                self.assertTrue(
+                                    np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+                                )
 
         # Test only ord provided
         for shape in [(3,), (2, 3)]:
-            x_mx = mx.arange(math.prod(shape)).reshape(shape)
-            x_np = np.arange(math.prod(shape)).reshape(shape)
+            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
+            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
             for o in [None, 1, -1, float("inf"), -float("inf")]:
                 for keepdims in [True, False]:
                     out_np = np.linalg.norm(x_np, ord=o, keepdims=keepdims)
                     out_mx = mx.linalg.norm(x_mx, ord=o, keepdims=keepdims)
-                    assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+                    with self.subTest(shape=shape, ord=o, keepdims=keepdims):
+                        self.assertTrue(
+                            np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+                        )
 
         # Test no ord and no axis provided
         for shape in [(3,), (2, 3), (2, 3, 3)]:
-            x_mx = mx.arange(math.prod(shape)).reshape(shape)
-            x_np = np.arange(math.prod(shape)).reshape(shape)
-            for o in [None, 1, -1, float("inf"), -float("inf")]:
-                for keepdims in [True, False]:
-                    out_np = np.linalg.norm(x_np, keepdims=keepdims)
-                    out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
-                    assert np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+            x_mx = mx.arange(1, math.prod(shape) + 1).reshape(shape)
+            x_np = np.arange(1, math.prod(shape) + 1).reshape(shape)
+            for keepdims in [True, False]:
+                out_np = np.linalg.norm(x_np, keepdims=keepdims)
+                out_mx = mx.linalg.norm(x_mx, keepdims=keepdims)
+                with self.subTest(shape=shape, keepdims=keepdims):
+                    self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
 
 
 if __name__ == "__main__":
     unittest.main()
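
A quick numerical cross-check of the matrix norms this patch fixes: ord=1/-1 are max/min column sums and ord=inf/-inf are max/min row sums, matching the expected values in the tests above. This standalone sketch uses the C++ API and assumes an MLX build with this series applied:

// Sketch: ord=+/-1 and +/-inf matrix norms on arange(9).reshape(3, 3).
#include <iostream>
#include <limits>
#include <vector>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  double inf = std::numeric_limits<double>::infinity();
  array b = reshape(arange(9), {3, 3}); // column sums 9, 12, 15; row sums 3, 12, 21
  std::vector<int> axes{0, 1};
  std::cout << linalg::norm(b, 1.0, axes).item<float>() << "\n";  // 15: max column sum
  std::cout << linalg::norm(b, -1.0, axes).item<float>() << "\n"; // 9:  min column sum
  std::cout << linalg::norm(b, inf, axes).item<float>() << "\n";  // 21: max row sum
  std::cout << linalg::norm(b, -inf, axes).item<float>() << "\n"; // 3:  min row sum
}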
From e87c2d4af3c5450909241fabfb16c863ec45b911 Mon Sep 17 00:00:00 2001
From: Gabrijel Boduljak
Date: Wed, 27 Dec 2023 02:12:33 +0100
Subject: [PATCH 36/37] fixed a bug with frobenius norm of a complex-valued
 matrix

---
 mlx/linalg.cpp | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index 9cce6cabb..e2c0e9f3f 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -4,6 +4,7 @@
 #include <numeric>
 #include <sstream>
 
+#include "mlx/dtype.h"
 #include "mlx/linalg.h"
 
 namespace mlx::core::linalg {
@@ -87,7 +88,10 @@ inline array matrix_norm(
     bool keepdims,
     StreamOrDevice s) {
   if (ord == "f" || ord == "fro") {
-    return sqrt(sum(square(a, s), axis, keepdims, s), s);
+    if (is_complex(a.dtype()))
+      return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
+    else
+      return sqrt(sum(square(a, s), axis, keepdims, s), s);
   } else if (ord == "nuc") {
     throw std::runtime_error(
         "[linalg::norm] Nuclear norm not yet implemented.");
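
Why this fix is needed: for a complex number z, square(z) = z*z is not |z|^2, so summing square(a) computes the wrong Frobenius norm for complex input. A plain std::complex illustration of the identity the patch relies on (standalone sketch, no MLX required):

// Sketch: |z|^2 versus z^2 for complex z.
#include <complex>
#include <iostream>

int main() {
  std::complex<float> z{3.0f, 4.0f};
  std::cout << std::norm(z) << "\n";              // |z|^2 = 25: what the Frobenius sum needs
  std::cout << (z * z) << "\n";                   // z^2 = (-7, 24): a different, complex value
  std::cout << std::abs(z) * std::abs(z) << "\n"; // abs(z)*abs(z) = 25, matching the patch
}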
From 673af67c92a11dfedbaacd5f588b404085ba5eb7 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 26 Dec 2023 19:40:20 -0800
Subject: [PATCH 37/37] complex for vector too

---
 mlx/linalg.cpp              | 21 +++++++++++++------
 python/tests/test_linalg.py | 27 +++++++++++++++++++++++++++
 2 files changed, 42 insertions(+), 6 deletions(-)

diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index e2c0e9f3f..7e7264e3f 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -13,6 +13,18 @@ Dtype at_least_float(const Dtype& d) {
   return is_floating_point(d) ? d : promote_types(d, float32);
 }
 
+inline array l2_norm(
+    const array& a,
+    const std::vector<int>& axis,
+    bool keepdims,
+    StreamOrDevice s) {
+  if (is_complex(a.dtype())) {
+    return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
+  } else {
+    return sqrt(sum(square(a, s), axis, keepdims, s), s);
+  }
+}
+
 inline array vector_norm(
     const array& a,
     const double ord,
@@ -25,7 +37,7 @@ inline array vector_norm(
   } else if (ord == 1.0) {
     return astype(sum(abs(a, s), axis, keepdims, s), dtype, s);
   } else if (ord == 2.0) {
-    return sqrt(sum(square(a, s), axis, keepdims, s), s);
+    return l2_norm(a, axis, keepdims, s);
   } else if (ord == std::numeric_limits<double>::infinity()) {
     return astype(max(abs(a, s), axis, keepdims, s), dtype, s);
   } else if (ord == -std::numeric_limits<double>::infinity()) {
@@ -88,10 +100,7 @@ inline array matrix_norm(
     bool keepdims,
     StreamOrDevice s) {
   if (ord == "f" || ord == "fro") {
-    if (is_complex(a.dtype()))
-      return sqrt(sum(abs(a, s) * abs(a, s), axis, keepdims, s), s);
-    else
-      return sqrt(sum(square(a, s), axis, keepdims, s), s);
+    return l2_norm(a, axis, keepdims, s);
   } else if (ord == "nuc") {
     throw std::runtime_error(
         "[linalg::norm] Nuclear norm not yet implemented.");
@@ -115,7 +124,7 @@ array norm(
     throw std::invalid_argument(
         "[linalg::norm] Received too many axes for norm.");
   }
-  return sqrt(sum(square(a, s), axis.value(), keepdims, s), s);
+  return l2_norm(a, axis.value(), keepdims, s);
 }
 
 array norm(
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index 08a4510c8..ac86c1e11 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -62,6 +62,33 @@ class TestLinalg(mlx_tests.MLXTestCase):
                     with self.subTest(shape=shape, keepdims=keepdims):
                         self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
 
+    def test_complex_norm(self):
+        for shape in [(3,), (2, 3), (2, 3, 3)]:
+            x_np = np.random.uniform(size=shape).astype(
+                np.float32
+            ) + 1j * np.random.uniform(size=shape).astype(np.float32)
+            x_mx = mx.array(x_np)
+            out_np = np.linalg.norm(x_np)
+            out_mx = mx.linalg.norm(x_mx)
+            with self.subTest(shape=shape):
+                self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
+            for num_axes in range(1, len(shape)):
+                for axis in itertools.combinations(range(len(shape)), num_axes):
+                    out_np = np.linalg.norm(x_np, axis=axis)
+                    out_mx = mx.linalg.norm(x_mx, axis=axis)
+                    with self.subTest(shape=shape, axis=axis):
+                        self.assertTrue(
+                            np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)
+                        )
+
+        x_np = np.random.uniform(size=(4, 4)).astype(
+            np.float32
+        ) + 1j * np.random.uniform(size=(4, 4)).astype(np.float32)
+        x_mx = mx.array(x_np)
+        out_np = np.linalg.norm(x_np, ord="fro")
+        out_mx = mx.linalg.norm(x_mx, ord="fro")
+        self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6))
+
 
 if __name__ == "__main__":
     unittest.main()
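
This final patch routes the vector 2-norm through the same l2_norm helper, so complex vectors also sum |z_i|^2 rather than z_i^2. The following standalone sketch (no MLX required) mirrors the is_complex branch of l2_norm on a small vector, along the lines of test_complex_norm:

// Sketch: the complex branch of l2_norm, computed by hand.
#include <cmath>
#include <complex>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::complex<float>> v{{1, 1}, {2, -1}, {3, 0}};
  float sum_sq = 0.0f;
  for (const auto& z : v) {
    sum_sq += std::norm(z); // |z|^2, as in the is_complex branch above
  }
  std::cout << std::sqrt(sum_sq) << "\n"; // sqrt(2 + 5 + 9) = 4
}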