mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
removed unused imports
This commit is contained in:
parent
bbfe042a2b
commit
4bae4a8239
@ -1,7 +1,5 @@
|
|||||||
// Copyright © 2023 Apple Inc.
|
// Copyright © 2023 Apple Inc.
|
||||||
|
|
||||||
#include <iostream>
|
|
||||||
#include <set>
|
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -9,7 +7,7 @@
|
|||||||
#include "mlx/array.h"
|
#include "mlx/array.h"
|
||||||
#include "mlx/linalg.h"
|
#include "mlx/linalg.h"
|
||||||
#include "mlx/ops.h"
|
#include "mlx/ops.h"
|
||||||
#include "utils.h"
|
#include "mlx/utils.h"
|
||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
|
|
||||||
@ -41,7 +39,7 @@ inline array vector_norm(
|
|||||||
return max(abs(a, s), axis, keepdims, s);
|
return max(abs(a, s), axis, keepdims, s);
|
||||||
else if (ord == "-inf")
|
else if (ord == "-inf")
|
||||||
return min(abs(a, s), axis, keepdims, s);
|
return min(abs(a, s), axis, keepdims, s);
|
||||||
std::stringstream error_stream;
|
std::ostringstream error_stream;
|
||||||
error_stream << "Invalid ord value " << ord;
|
error_stream << "Invalid ord value " << ord;
|
||||||
throw std::invalid_argument(error_stream.str());
|
throw std::invalid_argument(error_stream.str());
|
||||||
}
|
}
|
||||||
@ -62,7 +60,7 @@ inline array matrix_norm(
|
|||||||
return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
|
return max(sum(abs(a, s), row_axis, keepdims, s), col_axis, keepdims, s);
|
||||||
if (ord == 2.0 || ord == -2.0)
|
if (ord == 2.0 || ord == -2.0)
|
||||||
throw std::logic_error("Singular value norms are not implemented.");
|
throw std::logic_error("Singular value norms are not implemented.");
|
||||||
std::stringstream error_stream;
|
std::ostringstream error_stream;
|
||||||
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
||||||
throw std::invalid_argument(error_stream.str());
|
throw std::invalid_argument(error_stream.str());
|
||||||
}
|
}
|
||||||
@ -81,7 +79,7 @@ inline array matrix_norm(
|
|||||||
return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s);
|
return matrix_norm(a, -1.0, {axis[1], axis[0]}, keepdims, s);
|
||||||
if (ord == "nuc")
|
if (ord == "nuc")
|
||||||
throw std::logic_error("Nuclear norm is not implemented.");
|
throw std::logic_error("Nuclear norm is not implemented.");
|
||||||
std::stringstream error_stream;
|
std::ostringstream error_stream;
|
||||||
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
error_stream << "Invalid ord value " << ord << " for matrix norm";
|
||||||
throw std::invalid_argument(error_stream.str());
|
throw std::invalid_argument(error_stream.str());
|
||||||
}
|
}
|
||||||
@ -101,7 +99,7 @@ array norm(
|
|||||||
s),
|
s),
|
||||||
s);
|
s);
|
||||||
|
|
||||||
std::stringstream error_stream;
|
std::ostringstream error_stream;
|
||||||
error_stream << "Invalid axis values " << axis;
|
error_stream << "Invalid axis values " << axis;
|
||||||
throw std::invalid_argument(error_stream.str());
|
throw std::invalid_argument(error_stream.str());
|
||||||
}
|
}
|
||||||
@ -125,7 +123,7 @@ array norm(
|
|||||||
else if (num_axes == 2)
|
else if (num_axes == 2)
|
||||||
return matrix_norm(a, ord, ax, keepdims, s);
|
return matrix_norm(a, ord, ax, keepdims, s);
|
||||||
|
|
||||||
std::stringstream error_stream;
|
std::ostringstream error_stream;
|
||||||
error_stream << "Invalid axis values " << ax;
|
error_stream << "Invalid axis values " << ax;
|
||||||
throw std::invalid_argument(error_stream.str());
|
throw std::invalid_argument(error_stream.str());
|
||||||
}
|
}
|
||||||
@ -149,7 +147,7 @@ array norm(
|
|||||||
else if (num_axes == 2)
|
else if (num_axes == 2)
|
||||||
return matrix_norm(a, ord, ax, keepdims, s);
|
return matrix_norm(a, ord, ax, keepdims, s);
|
||||||
|
|
||||||
std::stringstream error_stream;
|
std::ostringstream error_stream;
|
||||||
error_stream << "Invalid axis values " << ax;
|
error_stream << "Invalid axis values " << ax;
|
||||||
throw std::invalid_argument(error_stream.str());
|
throw std::invalid_argument(error_stream.str());
|
||||||
}
|
}
|
||||||
|
@ -2,13 +2,10 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <variant>
|
|
||||||
|
|
||||||
#include "array.h"
|
#include "array.h"
|
||||||
#include "device.h"
|
#include "device.h"
|
||||||
#include "ops.h"
|
#include "ops.h"
|
||||||
#include "stream.h"
|
#include "stream.h"
|
||||||
#include "string.h"
|
|
||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
array norm(
|
array norm(
|
||||||
|
@ -286,10 +286,4 @@ std::ostream& operator<<(std::ostream& os, const std::vector<size_t>& v) {
|
|||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> get_shape_reducing_over_all_axes(int ndim) {
|
|
||||||
std::vector<int> shape(ndim);
|
|
||||||
std::iota(shape.begin(), shape.end(), 0);
|
|
||||||
return shape;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace mlx::core
|
} // namespace mlx::core
|
||||||
|
Loading…
Reference in New Issue
Block a user