mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Double for lapack (#1904)
* double for lapack ops * add double support for lapack ops
This commit is contained in:
@@ -18,6 +18,14 @@ void check_cpu_stream(const StreamOrDevice& s, const std::string& prefix) {
|
||||
"Explicitly pass a CPU stream to run it.");
|
||||
}
|
||||
}
|
||||
void check_float(Dtype dtype, const std::string& prefix) {
|
||||
if (dtype != float32 && dtype != float64) {
|
||||
std::ostringstream msg;
|
||||
msg << prefix << " Arrays must have type float32 or float64. "
|
||||
<< "Received array with type " << dtype << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
Dtype at_least_float(const Dtype& d) {
|
||||
return issubdtype(d, inexact) ? d : promote_types(d, float32);
|
||||
@@ -184,12 +192,8 @@ array norm(
|
||||
|
||||
std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::qr]");
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::qr] Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), "[linalg::qr]");
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::qr] Arrays must have >= 2 dimensions. Received array "
|
||||
@@ -212,12 +216,8 @@ std::pair<array, array> qr(const array& a, StreamOrDevice s /* = {} */) {
|
||||
|
||||
std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::svd]");
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::svd] Input array must have type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), "[linalg::svd]");
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::svd] Input array must have >= 2 dimensions. Received array "
|
||||
@@ -251,12 +251,8 @@ std::vector<array> svd(const array& a, StreamOrDevice s /* = {} */) {
|
||||
|
||||
array inv_impl(const array& a, bool tri, bool upper, StreamOrDevice s) {
|
||||
check_cpu_stream(s, "[linalg::inv]");
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::inv] Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), "[linalg::inv]");
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::inv] Arrays must have >= 2 dimensions. Received array "
|
||||
@@ -292,13 +288,7 @@ array cholesky(
|
||||
bool upper /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::cholesky]");
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::cholesky] Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
check_float(a.dtype(), "[linalg::cholesky]");
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array "
|
||||
@@ -321,12 +311,8 @@ array cholesky(
|
||||
|
||||
array pinv(const array& a, StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::pinv]");
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::pinv] Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), "[linalg::pinv]");
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::pinv] Arrays must have >= 2 dimensions. Received array "
|
||||
@@ -368,12 +354,7 @@ array cholesky_inv(
|
||||
bool upper /* = false */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
check_cpu_stream(s, "[linalg::cholesky_inv]");
|
||||
if (L.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << "[linalg::cholesky_inv] Arrays must type float32. Received array "
|
||||
<< "with type " << L.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(L.dtype(), "[linalg::cholesky_inv]");
|
||||
|
||||
if (L.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
@@ -474,12 +455,7 @@ void validate_eigh(
|
||||
const StreamOrDevice& stream,
|
||||
const std::string fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Arrays must have type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), fname);
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
@@ -524,12 +500,7 @@ void validate_lu(
|
||||
const StreamOrDevice& stream,
|
||||
const std::string& fname) {
|
||||
check_cpu_stream(stream, fname);
|
||||
if (a.dtype() != float32) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Arrays must type float32. Received array "
|
||||
<< "with type " << a.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
check_float(a.dtype(), fname);
|
||||
|
||||
if (a.ndim() < 2) {
|
||||
std::ostringstream msg;
|
||||
@@ -627,10 +598,12 @@ void validate_solve(
|
||||
}
|
||||
|
||||
auto out_type = promote_types(a.dtype(), b.dtype());
|
||||
if (out_type != float32) {
|
||||
if (out_type != float32 && out_type != float64) {
|
||||
std::ostringstream msg;
|
||||
msg << fname << " Input arrays must promote to float32. Received arrays "
|
||||
<< "with type " << a.dtype() << " and " << b.dtype() << ".";
|
||||
msg << fname
|
||||
<< " Input arrays must promote to float32 or float64. "
|
||||
" Received arrays with type "
|
||||
<< a.dtype() << " and " << b.dtype() << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user