mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
catch stream errors earlier to avoid aborts (#1801)
This commit is contained in:
parent
28091aa1ff
commit
2235dee906
@ -10,6 +10,15 @@
|
|||||||
|
|
||||||
namespace mlx::core::linalg {
|
namespace mlx::core::linalg {
|
||||||
|
|
||||||
|
void check_cpu_stream(const StreamOrDevice& s, const std::string& prefix) {
|
||||||
|
if (to_stream(s).device == Device::gpu) {
|
||||||
|
throw std::invalid_argument(
|
||||||
|
prefix +
|
||||||
|
" This op is not yet supported on the GPU. "
|
||||||
|
"Explicitly pass a CPU stream to run it.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Dtype at_least_float(const Dtype& d) {
|
Dtype at_least_float(const Dtype& d) {
|
||||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||||
}
|
}
|
||||||
@ -174,6 +183,7 @@ array norm(
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
check_cpu_stream(s, "[linalg::qr]");
|
||||||
if (a.dtype() != float32) {
|
if (a.dtype() != float32) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::qr] Arrays must type float32. Received array "
|
msg << "[linalg::qr] Arrays must type float32. Received array "
|
||||||
@ -201,6 +211,7 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
check_cpu_stream(s, "[linalg::svd]");
|
||||||
if (a.dtype() != float32) {
|
if (a.dtype() != float32) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::svd] Input array must have type float32. Received array "
|
msg << "[linalg::svd] Input array must have type float32. Received array "
|
||||||
@ -239,6 +250,7 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
|
array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
|
||||||
|
check_cpu_stream(s, "[linalg::inv]");
|
||||||
if (a.dtype() != float32) {
|
if (a.dtype() != float32) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::inv] Arrays must type float32. Received array "
|
msg << "[linalg::inv] Arrays must type float32. Received array "
|
||||||
@ -279,6 +291,7 @@ array cholesky(
|
|||||||
const array& a,
|
const array& a,
|
||||||
bool upper /* = false */,
|
bool upper /* = false */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
|
check_cpu_stream(s, "[linalg::cholesky]");
|
||||||
if (a.dtype() != float32) {
|
if (a.dtype() != float32) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::cholesky] Arrays must type float32. Received array "
|
msg << "[linalg::cholesky] Arrays must type float32. Received array "
|
||||||
@ -307,6 +320,7 @@ array cholesky(
|
|||||||
}
|
}
|
||||||
|
|
||||||
array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
||||||
|
check_cpu_stream(s, "[linalg::pinv]");
|
||||||
if (a.dtype() != float32) {
|
if (a.dtype() != float32) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::pinv] Arrays must type float32. Received array "
|
msg << "[linalg::pinv] Arrays must type float32. Received array "
|
||||||
@ -353,16 +367,17 @@ array cholesky_inv(
|
|||||||
const array& L,
|
const array& L,
|
||||||
bool upper /* = false */,
|
bool upper /* = false */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
|
check_cpu_stream(s, "[linalg::cholesky_inv]");
|
||||||
if (L.dtype() != float32) {
|
if (L.dtype() != float32) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::cholesky] Arrays must type float32. Received array "
|
msg << "[linalg::cholesky_inv] Arrays must type float32. Received array "
|
||||||
<< "with type " << L.dtype() << ".";
|
<< "with type " << L.dtype() << ".";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (L.ndim() < 2) {
|
if (L.ndim() < 2) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array "
|
msg << "[linalg::cholesky_inv] Arrays must have >= 2 dimensions. Received array "
|
||||||
"with "
|
"with "
|
||||||
<< L.ndim() << " dimensions.";
|
<< L.ndim() << " dimensions.";
|
||||||
throw std::invalid_argument(msg.str());
|
throw std::invalid_argument(msg.str());
|
||||||
@ -370,7 +385,7 @@ array cholesky_inv(
|
|||||||
|
|
||||||
if (L.shape(-1) != L.shape(-2)) {
|
if (L.shape(-1) != L.shape(-2)) {
|
||||||
throw std::invalid_argument(
|
throw std::invalid_argument(
|
||||||
"[linalg::cholesky] Cholesky inverse is only defined for square "
|
"[linalg::cholesky_inv] Cholesky inverse is only defined for square "
|
||||||
"matrices.");
|
"matrices.");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -454,7 +469,11 @@ array cross(
|
|||||||
return concatenate(outputs, axis, s);
|
return concatenate(outputs, axis, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
void validate_eigh(const array& a, const std::string fname) {
|
void validate_eigh(
|
||||||
|
const array& a,
|
||||||
|
const StreamOrDevice& stream,
|
||||||
|
const std::string fname) {
|
||||||
|
check_cpu_stream(stream, fname);
|
||||||
if (a.dtype() != float32) {
|
if (a.dtype() != float32) {
|
||||||
std::ostringstream msg;
|
std::ostringstream msg;
|
||||||
msg << fname << " Arrays must have type float32. Received array "
|
msg << fname << " Arrays must have type float32. Received array "
|
||||||
@ -478,7 +497,7 @@ array eigvalsh(
|
|||||||
const array& a,
|
const array& a,
|
||||||
std::string UPLO /* = "L" */,
|
std::string UPLO /* = "L" */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
validate_eigh(a, "[linalg::eigvalsh]");
|
validate_eigh(a, s, "[linalg::eigvalsh]");
|
||||||
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
|
Shape out_shape(a.shape().begin(), a.shape().end() - 1);
|
||||||
return array(
|
return array(
|
||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
@ -491,7 +510,7 @@ std::pair<array, array> eigh(
|
|||||||
const array& a,
|
const array& a,
|
||||||
std::string UPLO /* = "L" */,
|
std::string UPLO /* = "L" */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
validate_eigh(a, "[linalg::eigh]");
|
validate_eigh(a, s, "[linalg::eigh]");
|
||||||
auto out = array::make_arrays(
|
auto out = array::make_arrays(
|
||||||
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
|
{Shape(a.shape().begin(), a.shape().end() - 1), a.shape()},
|
||||||
{a.dtype(), a.dtype()},
|
{a.dtype(), a.dtype()},
|
||||||
|
@ -385,7 +385,7 @@ def sparse(
|
|||||||
raise ValueError("Only tensors with 2 dimensions are supported")
|
raise ValueError("Only tensors with 2 dimensions are supported")
|
||||||
|
|
||||||
rows, cols = a.shape
|
rows, cols = a.shape
|
||||||
num_zeros = int(mx.ceil(sparsity * cols))
|
num_zeros = int(math.ceil(sparsity * cols))
|
||||||
|
|
||||||
order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1)
|
order = mx.argsort(mx.random.uniform(shape=a.shape), axis=1)
|
||||||
a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)
|
a = mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype)
|
||||||
|
@ -421,6 +421,9 @@ TEST_CASE("test random normal") {
|
|||||||
}
|
}
|
||||||
|
|
||||||
TEST_CASE("test random multivariate_normal") {
|
TEST_CASE("test random multivariate_normal") {
|
||||||
|
// Scope switch to the cpu for SVDs
|
||||||
|
StreamContext sc(Device::cpu);
|
||||||
|
|
||||||
{
|
{
|
||||||
auto mean = zeros({3});
|
auto mean = zeros({3});
|
||||||
auto cov = eye(3);
|
auto cov = eye(3);
|
||||||
|
Loading…
Reference in New Issue
Block a user