2023-12-27 04:42:04 +01:00
|
|
|
// Copyright © 2023 Apple Inc.
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <optional>
|
|
|
|
|
|
|
|
|
|
#include "mlx/array.h"
|
|
|
|
|
#include "mlx/device.h"
|
|
|
|
|
#include "mlx/ops.h"
|
|
|
|
|
#include "mlx/stream.h"
|
|
|
|
|
|
|
|
|
|
namespace mlx::core::linalg {
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Compute vector or matrix norms.
|
|
|
|
|
*
|
|
|
|
|
* - If axis and ord are both unspecified, computes the 2-norm of flatten(x).
|
|
|
|
|
* - If axis is not provided but ord is, then x must be either 1D or 2D.
|
|
|
|
|
* - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm
|
|
|
|
|
* for matrices) is computed along the given axes. At most 2 axes can be
|
|
|
|
|
* specified.
|
|
|
|
|
* - If both axis and ord are provided, then the corresponding matrix or vector
|
|
|
|
|
* norm is computed. At most 2 axes can be specified.
|
|
|
|
|
*/
|
|
|
|
|
array norm(
|
|
|
|
|
const array& a,
|
|
|
|
|
const double ord,
|
|
|
|
|
const std::optional<std::vector<int>>& axis = std::nullopt,
|
|
|
|
|
bool keepdims = false,
|
|
|
|
|
StreamOrDevice s = {});
|
|
|
|
|
inline array norm(
|
|
|
|
|
const array& a,
|
|
|
|
|
const double ord,
|
|
|
|
|
int axis,
|
|
|
|
|
bool keepdims = false,
|
|
|
|
|
StreamOrDevice s = {}) {
|
|
|
|
|
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
|
|
|
|
|
}
|
|
|
|
|
array norm(
|
|
|
|
|
const array& a,
|
2024-04-09 11:22:00 -07:00
|
|
|
const std::string& ord,
|
2023-12-27 04:42:04 +01:00
|
|
|
const std::optional<std::vector<int>>& axis = std::nullopt,
|
|
|
|
|
bool keepdims = false,
|
|
|
|
|
StreamOrDevice s = {});
|
|
|
|
|
inline array norm(
|
|
|
|
|
const array& a,
|
2024-04-09 11:22:00 -07:00
|
|
|
const std::string& ord,
|
2023-12-27 04:42:04 +01:00
|
|
|
int axis,
|
|
|
|
|
bool keepdims = false,
|
|
|
|
|
StreamOrDevice s = {}) {
|
|
|
|
|
return norm(a, ord, std::vector<int>{axis}, keepdims, s);
|
|
|
|
|
}
|
|
|
|
|
array norm(
|
|
|
|
|
const array& a,
|
|
|
|
|
const std::optional<std::vector<int>>& axis = std::nullopt,
|
|
|
|
|
bool keepdims = false,
|
|
|
|
|
StreamOrDevice s = {});
|
|
|
|
|
inline array
|
|
|
|
|
norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) {
|
|
|
|
|
return norm(a, std::vector<int>{axis}, keepdims, s);
|
|
|
|
|
}
|
|
|
|
|
|
2024-01-26 09:27:31 -08:00
|
|
|
std::pair<array, array> qr(const array& a, StreamOrDevice s = {});
|
|
|
|
|
|
2024-03-12 20:30:11 +01:00
|
|
|
std::vector<array> svd(const array& a, StreamOrDevice s = {});
|
|
|
|
|
|
2024-03-15 14:34:36 +01:00
|
|
|
array inv(const array& a, StreamOrDevice s = {});
|
|
|
|
|
|
2023-12-27 04:42:04 +01:00
|
|
|
} // namespace mlx::core::linalg
|