From 4bae4a8239cc64a8be7e038d3e096253c03708ab Mon Sep 17 00:00:00 2001 From: Gabrijel Boduljak Date: Sun, 24 Dec 2023 04:40:44 +0100 Subject: [PATCH] removed unused imports --- mlx/linalg.cpp | 16 +++++++--------- mlx/linalg.h | 3 --- mlx/utils.cpp | 6 ------ 3 files changed, 7 insertions(+), 18 deletions(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index f541c6214..33fa92083 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -1,7 +1,5 @@ // Copyright © 2023 Apple Inc. -#include -#include #include #include #include @@ -9,7 +7,7 @@ #include "mlx/array.h" #include "mlx/linalg.h" #include "mlx/ops.h" -#include "utils.h" +#include "mlx/utils.h" namespace mlx::core::linalg { @@ -41,7 +39,7 @@ inline array vector_norm( return max(abs(a, s), axis, keepdims, s); else if (ord == "-inf") return min(abs(a, s), axis, keepdims, s); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid ord value " << ord; throw std::invalid_argument(error_stream.str()); } @@ -62,7 +60,7 @@ inline array matrix_norm( return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s); if (ord == 2.0 || ord == -2.0) throw std::logic_error("Singular value norms are not implemented."); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid ord value " << ord << " for matrix norm"; throw std::invalid_argument(error_stream.str()); } @@ -81,7 +79,7 @@ inline array matrix_norm( return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s); if (ord == "nuc") throw std::logic_error("Nuclear norm is not implemented."); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid ord value " << ord << " for matrix norm"; throw std::invalid_argument(error_stream.str()); } @@ -101,7 +99,7 @@ array norm( s), s); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid axis values " << axis; throw std::invalid_argument(error_stream.str()); } @@ -125,7 +123,7 @@ array norm( else if (num_axes == 2) return matrix_norm(a, ord, ax, keepdims, s); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid axis values " << ax; throw std::invalid_argument(error_stream.str()); } @@ -149,7 +147,7 @@ array norm( else if (num_axes == 2) return matrix_norm(a, ord, ax, keepdims, s); - std::stringstream error_stream; + std::ostringstream error_stream; error_stream << "Invalid axis values " << ax; throw std::invalid_argument(error_stream.str()); } diff --git a/mlx/linalg.h b/mlx/linalg.h index 690df343c..d77ada477 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -2,13 +2,10 @@ #pragma once -#include - #include "array.h" #include "device.h" #include "ops.h" #include "stream.h" -#include "string.h" namespace mlx::core::linalg { array norm( diff --git a/mlx/utils.cpp b/mlx/utils.cpp index ddcb41ba8..932217ad4 100644 --- a/mlx/utils.cpp +++ b/mlx/utils.cpp @@ -286,10 +286,4 @@ std::ostream& operator<<(std::ostream& os, const std::vector& v) { return os; } -std::vector get_shape_reducing_over_all_axes(int ndim) { - std::vector shape(ndim); - std::iota(shape.begin(), shape.end(), 0); - return shape; -} - } // namespace mlx::core