removed unused imports

This commit is contained in:
Gabrijel Boduljak 2023-12-24 04:40:44 +01:00 committed by Awni Hannun
parent bbfe042a2b
commit 4bae4a8239
3 changed files with 7 additions and 18 deletions

View File

@ -1,7 +1,5 @@
// Copyright © 2023 Apple Inc.
#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>
@ -9,7 +7,7 @@
#include "mlx/array.h"
#include "mlx/linalg.h"
#include "mlx/ops.h"
#include "utils.h"
#include "mlx/utils.h"
namespace mlx::core::linalg {
@ -41,7 +39,7 @@ inline array vector_norm(
return max(abs(a, s), axis, keepdims, s);
else if (ord == "-inf")
return min(abs(a, s), axis, keepdims, s);
std::stringstream error_stream;
std::ostringstream error_stream;
error_stream << "Invalid ord value " << ord;
throw std::invalid_argument(error_stream.str());
}
@ -62,7 +60,7 @@ inline array matrix_norm(
return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
if (ord == 2.0 || ord == -2.0)
throw std::logic_error("Singular value norms are not implemented.");
std::stringstream error_stream;
std::ostringstream error_stream;
error_stream << "Invalid ord value " << ord << " for matrix norm";
throw std::invalid_argument(error_stream.str());
}
@ -81,7 +79,7 @@ inline array matrix_norm(
return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s);
if (ord == "nuc")
throw std::logic_error("Nuclear norm is not implemented.");
std::stringstream error_stream;
std::ostringstream error_stream;
error_stream << "Invalid ord value " << ord << " for matrix norm";
throw std::invalid_argument(error_stream.str());
}
@ -101,7 +99,7 @@ array norm(
s),
s);
std::stringstream error_stream;
std::ostringstream error_stream;
error_stream << "Invalid axis values " << axis;
throw std::invalid_argument(error_stream.str());
}
@ -125,7 +123,7 @@ array norm(
else if (num_axes == 2)
return matrix_norm(a, ord, ax, keepdims, s);
std::stringstream error_stream;
std::ostringstream error_stream;
error_stream << "Invalid axis values " << ax;
throw std::invalid_argument(error_stream.str());
}
@ -149,7 +147,7 @@ array norm(
else if (num_axes == 2)
return matrix_norm(a, ord, ax, keepdims, s);
std::stringstream error_stream;
std::ostringstream error_stream;
error_stream << "Invalid axis values " << ax;
throw std::invalid_argument(error_stream.str());
}

View File

@ -2,13 +2,10 @@
#pragma once
#include <variant>
#include "array.h"
#include "device.h"
#include "ops.h"
#include "stream.h"
#include "string.h"
namespace mlx::core::linalg {
array norm(

View File

@ -286,10 +286,4 @@ std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
return os;
}
std::vector<int> get_shape_reducing_over_all_axes(int ndim) {
std::vector<int> shape(ndim);
std::iota(shape.begin(), shape.end(), 0);
return shape;
}
} // namespace mlx::core