mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
removed unused imports
This commit is contained in:
parent
bbfe042a2b
commit
4bae4a8239
@ -1,7 +1,5 @@
|
||||
// Copyright © 2023 Apple Inc.
|
||||
|
||||
#include <iostream>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -9,7 +7,7 @@
|
||||
#include "mlx/array.h"
|
||||
#include "mlx/linalg.h"
|
||||
#include "mlx/ops.h"
|
||||
#include "utils.h"
|
||||
#include "mlx/utils.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
|
||||
@ -41,7 +39,7 @@ inline array vector_norm(
|
||||
return max(abs(a, s), axis, keepdims, s);
|
||||
else if (ord == "-inf")
|
||||
return min(abs(a, s), axis, keepdims, s);
|
||||
std::stringstream error_stream;
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid ord value " << ord;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
@ -62,7 +60,7 @@ inline array matrix_norm(
|
||||
return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
|
||||
if (ord == 2.0 || ord == -2.0)
|
||||
throw std::logic_error("Singular value norms are not implemented.");
|
||||
std::stringstream error_stream;
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
@ -81,7 +79,7 @@ inline array matrix_norm(
|
||||
return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s);
|
||||
if (ord == "nuc")
|
||||
throw std::logic_error("Nuclear norm is not implemented.");
|
||||
std::stringstream error_stream;
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
@ -101,7 +99,7 @@ array norm(
|
||||
s),
|
||||
s);
|
||||
|
||||
std::stringstream error_stream;
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid axis values " << axis;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
@ -125,7 +123,7 @@ array norm(
|
||||
else if (num_axes == 2)
|
||||
return matrix_norm(a, ord, ax, keepdims, s);
|
||||
|
||||
std::stringstream error_stream;
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid axis values " << ax;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
@ -149,7 +147,7 @@ array norm(
|
||||
else if (num_axes == 2)
|
||||
return matrix_norm(a, ord, ax, keepdims, s);
|
||||
|
||||
std::stringstream error_stream;
|
||||
std::ostringstream error_stream;
|
||||
error_stream << "Invalid axis values " << ax;
|
||||
throw std::invalid_argument(error_stream.str());
|
||||
}
|
||||
|
@ -2,13 +2,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <variant>
|
||||
|
||||
#include "array.h"
|
||||
#include "device.h"
|
||||
#include "ops.h"
|
||||
#include "stream.h"
|
||||
#include "string.h"
|
||||
|
||||
namespace mlx::core::linalg {
|
||||
array norm(
|
||||
|
@ -286,10 +286,4 @@ std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
|
||||
return os;
|
||||
}
|
||||
|
||||
std::vector<int> get_shape_reducing_over_all_axes(int ndim) {
|
||||
std::vector<int> shape(ndim);
|
||||
std::iota(shape.begin(), shape.end(), 0);
|
||||
return shape;
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
Loading…
Reference in New Issue
Block a user