mlx/mlx/linalg.h
Gabrijel Boduljak cc9b2dc3c2 implemented vector_norm in cpp
added linalg to mlx
2023-12-26 19:40:34 -08:00

45 lines
1.1 KiB
C++

// Copyright © 2023 Apple Inc.
#pragma once
#include <string>
#include <variant>
#include <vector>
#include "array.h"
#include "device.h"
#include "ops.h"
#include "stream.h"
#include "string.h"
namespace mlx::core::linalg {
// Visitor helper for std::visit: inherits from every callable it is given and
// re-exports each base's call operator, so a set of lambdas behaves like a
// single functor with one overload per lambda.
template <typename... Callables>
struct overloaded : Callables... {
  using Callables::operator()...;
};

// Deduction guide: lets `overloaded{f, g, ...}` deduce its template arguments
// from the callables in the braced initializer (needed in C++17).
template <typename... Callables>
overloaded(Callables...) -> overloaded<Callables...>;
/*
Computes a vector norm of `a`.

If axes = {}, a will be flattened before the norm is computed.
Otherwise, the norm is computed over axes and the other dimensions are
treated as batch dimensions.

Parameters:
  a         input array.
  ord       order of the norm, given either as a number or as a string;
            defaults to 2.0. NOTE(review): the set of accepted string values
            is decided by the implementation and is not visible in this
            header.
  axes      axes the norm is computed over; {} means "flatten first".
  keepdims  presumably keeps the reduced axes as size-one dimensions in the
            result - confirm against the implementation.
  s         stream or device argument; presumably selects where the
            computation runs - confirm against the implementation.

Overload set:
  Only this most general overload carries a full set of default arguments.
  The shorter overloads below require the parameter that distinguishes them.
  If every overload defaulted everything, `vector_norm(a)` would match all
  four equally well and be rejected as ambiguous, and a call such as
  `vector_norm(a, std::string{"inf"})` would be ambiguous as well.

  WARNING: a bare `double` or string literal as the second argument, e.g.
  `vector_norm(a, 2.0)` or `vector_norm(a, "inf")`, converts to `bool` through
  a standard conversion, which ranks above the user-defined conversion to
  std::variant. Such calls therefore select the `keepdims` overload. Pass a
  std::variant<double, std::string> (or a std::string) explicitly, or supply
  `axes` as well, to select the norm order.
*/
array vector_norm(
    const array& a,
    const std::variant<double, std::string>& ord = 2.0,
    const std::vector<int>& axes = {},
    bool keepdims = false,
    StreamOrDevice s = {});

// Overload without `axes`. `ord` and `keepdims` must be given explicitly so
// that shorter calls resolve unambiguously to the general overload above.
// Presumably equivalent to the general overload with axes = {} - confirm.
array vector_norm(
    const array& a,
    const std::variant<double, std::string>& ord,
    bool keepdims,
    StreamOrDevice s = {});

// Overload without `ord`. `axes` must be given explicitly.
// Presumably equivalent to the general overload with the default order -
// confirm.
array vector_norm(
    const array& a,
    const std::vector<int>& axes,
    bool keepdims = false,
    StreamOrDevice s = {});

// Overload without `ord` and `axes`. `keepdims` must be given explicitly.
// Presumably equivalent to the general overload with the default order and
// axes = {} - confirm.
array vector_norm(const array& a, bool keepdims, StreamOrDevice s = {});
} // namespace mlx::core::linalg